NIOThrowingAsyncSequenceProducer throws when cancelled (#2415)

* NIOThrowingAsyncSequenceProducer throws when cancelled

* PR review
This commit is contained in:
Fabian Fett 2023-04-28 17:21:15 +02:00 committed by GitHub
parent 5f8b0647e4
commit d1690f8541
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 75 additions and 9 deletions

View File

@ -511,6 +511,21 @@ extension NIOThrowingAsyncSequenceProducer {
return nil return nil
} }
case .returnCancellationError:
self._lock.unlock()
// We have deprecated the generic Failure type in the public API and Failure should
// now be `Swift.Error`. However, if users have not migrated to the new API they could
// still use a custom generic Error type and this cast might fail.
// In addition, we use `NIOThrowingAsyncSequenceProducer` in the implementation of the
// non-throwing variant `NIOAsyncSequenceProducer` where `Failure` will be `Never` and
// this cast will fail as well.
// Everything is marked @inlinable and the Failure type is known at compile time,
// therefore this cast should be optimised away in release build.
if let error = CancellationError() as? Failure {
throw error
}
return nil
case .returnNil: case .returnNil:
self._lock.unlock() self._lock.unlock()
return nil return nil
@ -603,6 +618,9 @@ extension NIOThrowingAsyncSequenceProducer {
failure: Failure? failure: Failure?
) )
/// The state once a call to next has been cancelled. Cancel the source when entering this state.
case cancelled(iteratorInitialized: Bool)
/// The state once there can be no outstanding demand. This can happen if: /// The state once there can be no outstanding demand. This can happen if:
/// 1. The ``NIOThrowingAsyncSequenceProducer/AsyncIterator`` was deinited /// 1. The ``NIOThrowingAsyncSequenceProducer/AsyncIterator`` was deinited
/// 2. The underlying source finished and all buffered elements have been consumed /// 2. The underlying source finished and all buffered elements have been consumed
@ -644,7 +662,8 @@ extension NIOThrowingAsyncSequenceProducer {
switch self._state { switch self._state {
case .initial(_, iteratorInitialized: false), case .initial(_, iteratorInitialized: false),
.streaming(_, _, _, _, iteratorInitialized: false), .streaming(_, _, _, _, iteratorInitialized: false),
.sourceFinished(_, iteratorInitialized: false, _): .sourceFinished(_, iteratorInitialized: false, _),
.cancelled(iteratorInitialized: false):
// No iterator was created so we can transition to finished right away. // No iterator was created so we can transition to finished right away.
self._state = .finished(iteratorInitialized: false) self._state = .finished(iteratorInitialized: false)
@ -652,7 +671,8 @@ extension NIOThrowingAsyncSequenceProducer {
case .initial(_, iteratorInitialized: true), case .initial(_, iteratorInitialized: true),
.streaming(_, _, _, _, iteratorInitialized: true), .streaming(_, _, _, _, iteratorInitialized: true),
.sourceFinished(_, iteratorInitialized: true, _): .sourceFinished(_, iteratorInitialized: true, _),
.cancelled(iteratorInitialized: true):
// An iterator was created and we deinited the sequence. // An iterator was created and we deinited the sequence.
// This is an expected pattern and we just continue on normal. // This is an expected pattern and we just continue on normal.
return .none return .none
@ -673,6 +693,7 @@ extension NIOThrowingAsyncSequenceProducer {
case .initial(_, iteratorInitialized: true), case .initial(_, iteratorInitialized: true),
.streaming(_, _, _, _, iteratorInitialized: true), .streaming(_, _, _, _, iteratorInitialized: true),
.sourceFinished(_, iteratorInitialized: true, _), .sourceFinished(_, iteratorInitialized: true, _),
.cancelled(iteratorInitialized: true),
.finished(iteratorInitialized: true): .finished(iteratorInitialized: true):
// Our sequence is a unicast sequence and does not support multiple AsyncIterator's // Our sequence is a unicast sequence and does not support multiple AsyncIterator's
fatalError("NIOThrowingAsyncSequenceProducer allows only a single AsyncIterator to be created") fatalError("NIOThrowingAsyncSequenceProducer allows only a single AsyncIterator to be created")
@ -694,6 +715,10 @@ extension NIOThrowingAsyncSequenceProducer {
iteratorInitialized: true iteratorInitialized: true
) )
case .cancelled(iteratorInitialized: false):
// An iterator needs to be initialized before we can be cancelled.
preconditionFailure("Internal inconsistency")
case .sourceFinished(let buffer, false, let failure): case .sourceFinished(let buffer, false, let failure):
// The first and only iterator was initialized. // The first and only iterator was initialized.
self._state = .sourceFinished( self._state = .sourceFinished(
@ -727,13 +752,15 @@ extension NIOThrowingAsyncSequenceProducer {
switch self._state { switch self._state {
case .initial(_, iteratorInitialized: false), case .initial(_, iteratorInitialized: false),
.streaming(_, _, _, _, iteratorInitialized: false), .streaming(_, _, _, _, iteratorInitialized: false),
.sourceFinished(_, iteratorInitialized: false, _): .sourceFinished(_, iteratorInitialized: false, _),
.cancelled(iteratorInitialized: false):
// An iterator needs to be initialized before it can be deinitialized. // An iterator needs to be initialized before it can be deinitialized.
preconditionFailure("Internal inconsistency") preconditionFailure("Internal inconsistency")
case .initial(_, iteratorInitialized: true), case .initial(_, iteratorInitialized: true),
.streaming(_, _, _, _, iteratorInitialized: true), .streaming(_, _, _, _, iteratorInitialized: true),
.sourceFinished(_, iteratorInitialized: true, _): .sourceFinished(_, iteratorInitialized: true, _),
.cancelled(iteratorInitialized: true):
// An iterator was created and deinited. Since we only support // An iterator was created and deinited. Since we only support
// a single iterator we can now transition to finish and inform the delegate. // a single iterator we can now transition to finish and inform the delegate.
self._state = .finished(iteratorInitialized: true) self._state = .finished(iteratorInitialized: true)
@ -861,7 +888,7 @@ extension NIOThrowingAsyncSequenceProducer {
return .init(shouldProduceMore: shouldProduceMore) return .init(shouldProduceMore: shouldProduceMore)
case .sourceFinished, .finished: case .cancelled, .sourceFinished, .finished:
// If the source has finished we are dropping the elements. // If the source has finished we are dropping the elements.
return .returnDropped return .returnDropped
@ -913,7 +940,7 @@ extension NIOThrowingAsyncSequenceProducer {
return .none return .none
case .sourceFinished, .finished: case .cancelled, .sourceFinished, .finished:
// If the source has finished, finishing again has no effect. // If the source has finished, finishing again has no effect.
return .none return .none
@ -968,11 +995,14 @@ extension NIOThrowingAsyncSequenceProducer {
return .resumeContinuationWithCancellationErrorAndCallDidTerminate(continuation) return .resumeContinuationWithCancellationErrorAndCallDidTerminate(continuation)
case .streaming(_, _, continuation: .none, _, let iteratorInitialized): case .streaming(_, _, continuation: .none, _, let iteratorInitialized):
self._state = .finished(iteratorInitialized: iteratorInitialized) // We may have elements in the buffer, which is why we have no continuation
// waiting. We must store the cancellation error to hand it out on the next
// next() call.
self._state = .cancelled(iteratorInitialized: iteratorInitialized)
return .callDidTerminate return .callDidTerminate
case .sourceFinished, .finished: case .cancelled, .sourceFinished, .finished:
// If the source has finished, finishing again has no effect. // If the source has finished, finishing again has no effect.
return .none return .none
@ -992,6 +1022,8 @@ extension NIOThrowingAsyncSequenceProducer {
/// Indicates that the `Failure` should be returned to the caller and /// Indicates that the `Failure` should be returned to the caller and
/// that ``NIOAsyncSequenceProducerDelegate/didTerminate()`` should be called. /// that ``NIOAsyncSequenceProducerDelegate/didTerminate()`` should be called.
case returnFailureAndCallDidTerminate(Failure?) case returnFailureAndCallDidTerminate(Failure?)
/// Indicates that the next call to AsyncSequence got cancelled
case returnCancellationError
/// Indicates that the `nil` should be returned to the caller. /// Indicates that the `nil` should be returned to the caller.
case returnNil case returnNil
/// Indicates that the `Task` of the caller should be suspended. /// Indicates that the `Task` of the caller should be suspended.
@ -1075,6 +1107,10 @@ extension NIOThrowingAsyncSequenceProducer {
return .returnFailureAndCallDidTerminate(failure) return .returnFailureAndCallDidTerminate(failure)
} }
case .cancelled(let iteratorInitialized):
self._state = .finished(iteratorInitialized: iteratorInitialized)
return .returnCancellationError
case .finished: case .finished:
return .returnNil return .returnNil
@ -1119,7 +1155,7 @@ extension NIOThrowingAsyncSequenceProducer {
return .none return .none
} }
case .streaming(_, _, .some(_), _, _), .sourceFinished, .finished: case .streaming(_, _, .some(_), _, _), .sourceFinished, .finished, .cancelled:
preconditionFailure("This should have already been handled by `next()`") preconditionFailure("This should have already been handled by `next()`")
case .modifying: case .modifying:

View File

@ -743,6 +743,36 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase {
XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate]) XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate])
} }
func testIteratorThrows_whenCancelled() async {
_ = self.source.yield(contentsOf: Array(0..<100))
await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
var counter = 0
guard let sequence = self.sequence else {
return XCTFail("Expected to have an AsyncSequence")
}
do {
for try await next in sequence {
XCTAssertEqual(next, counter)
counter += 1
}
XCTFail("Expected that this throws")
} catch is CancellationError {
// expected
} catch {
XCTFail("Unexpected error: \(error)")
}
XCTAssertLessThan(counter, 100)
}
group.cancelAll()
}
XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate])
}
} }
// This is needed until async let is supported to be used in autoclosures // This is needed until async let is supported to be used in autoclosures