diff --git a/Sources/NIO/Embedded.swift b/Sources/NIO/Embedded.swift index 42af3a0d..a0e6c915 100644 --- a/Sources/NIO/Embedded.swift +++ b/Sources/NIO/Embedded.swift @@ -57,8 +57,6 @@ public class EmbeddedEventLoop: EventLoop { return true } - var tasks = CircularBuffer<() -> Void>(initialRingCapacity: 2) - public init() { } public func scheduleTask(in: TimeAmount, _ task: @escaping () throws -> T) -> Scheduled { @@ -83,14 +81,12 @@ public class EmbeddedEventLoop: EventLoop { // at which point we run everything that's been submitted. Anything newly submitted // either gets on that train if it's still moving or waits until the next call to run(). public func execute(_ task: @escaping () -> Void) { - tasks.append(task) + _ = self.scheduleTask(in: .nanoseconds(0), task) } public func run() { - // Execute all tasks that are currently enqueued. - while !tasks.isEmpty { - tasks.removeFirst()() - } + // Execute all tasks that are currently enqueued to be executed *now*. + self.advanceTime(by: .nanoseconds(0)) } /// Runs the event loop and moves "time" forward by the given amount, running any scheduled @@ -98,22 +94,26 @@ public class EmbeddedEventLoop: EventLoop { public func advanceTime(by: TimeAmount) { let newTime = self.now + UInt64(by.nanoseconds) - // First, run the event loop to dispatch any current work. - self.run() - while let nextTask = self.scheduledTasks.peek() { guard nextTask.readyTime <= newTime else { break } - // Set the time correctly before we call into user code, then - // call in. Once we've done that, spin the event loop in case any - // work was scheduled by the delayed task. - _ = self.scheduledTasks.pop() - self.now = nextTask.readyTime - nextTask.task() + // Now we want to grab all tasks that are ready to execute at the same + // time as the first. + var tasks = Array() + while let candidateTask = self.scheduledTasks.peek(), candidateTask.readyTime == nextTask.readyTime { + tasks.append(candidateTask) + _ = self.scheduledTasks.pop() + } - self.run() + // Set the time correctly before we call into user code, then + // call in for all tasks. + self.now = nextTask.readyTime + + for task in tasks { + task.task() + } } // Finally ensure we got the time right. @@ -132,7 +132,6 @@ public class EmbeddedEventLoop: EventLoop { } deinit { - precondition(tasks.isEmpty, "Embedded event loop freed with unexecuted tasks!") precondition(scheduledTasks.isEmpty, "Embedded event loop freed with unexecuted scheduled tasks!") } } diff --git a/Tests/NIOTests/EmbeddedEventLoopTest+XCTest.swift b/Tests/NIOTests/EmbeddedEventLoopTest+XCTest.swift index 0f945778..db6144a9 100644 --- a/Tests/NIOTests/EmbeddedEventLoopTest+XCTest.swift +++ b/Tests/NIOTests/EmbeddedEventLoopTest+XCTest.swift @@ -39,6 +39,7 @@ extension EmbeddedEventLoopTest { ("testCancellingScheduledTasks", testCancellingScheduledTasks), ("testScheduledTasksFuturesFire", testScheduledTasksFuturesFire), ("testScheduledTasksFuturesError", testScheduledTasksFuturesError), + ("testTaskOrdering", testTaskOrdering), ] } } diff --git a/Tests/NIOTests/EmbeddedEventLoopTest.swift b/Tests/NIOTests/EmbeddedEventLoopTest.swift index 989c027a..a161f722 100644 --- a/Tests/NIOTests/EmbeddedEventLoopTest.swift +++ b/Tests/NIOTests/EmbeddedEventLoopTest.swift @@ -224,4 +224,91 @@ public class EmbeddedEventLoopTest: XCTestCase { loop.advanceTime(by: .nanoseconds(1)) XCTAssertTrue(fired) } + + func testTaskOrdering() { + // This test validates that the ordering of task firing on EmbeddedEventLoop via + // advanceTime(by:) is the same as on MultiThreadedEventLoopGroup: specifically, that tasks run via + // schedule that expire "now" all run at the same time, and that any work they schedule is run + // after all such tasks expire. + let loop = EmbeddedEventLoop() + var firstScheduled: Scheduled? = nil + var secondScheduled: Scheduled? = nil + var orderingCounter = 0 + + // Here's the setup. First, we'll set up two scheduled tasks to fire in 5 nanoseconds. Each of these + // will attempt to cancel the other, whichever fires first. Additionally, each will execute{} a single + // callback. Then we'll execute {} one other callback. Finally we'll schedule a task for 10ns, before + // we advance time. The ordering should be as follows: + // + // 1. The task executed by execute {} from this function. + // 2. One of the first scheduled tasks. + // 3. The other first scheduled task (note that the cancellation will fail). + // 4. One of the execute {} callbacks from a scheduled task. + // 5. The other execute {} callbacks from the scheduled task. + // 6. The 10ns task. + // + // To validate the ordering, we'll use a counter. + + func delayedExecute() { + // The acceptable value for the delayed execute callbacks is 3 or 4. + XCTAssertTrue(orderingCounter == 3 || orderingCounter == 4, "Invalid counter value \(orderingCounter)") + orderingCounter += 1 + } + + firstScheduled = loop.scheduleTask(in: .nanoseconds(5)) { + firstScheduled = nil + + let expected: Int + if let partner = secondScheduled { + // Ok, this callback fired first. Cancel the other, then set the expected current + // counter value to 1. + partner.cancel() + expected = 1 + } else { + // This callback fired second. + expected = 2 + } + + XCTAssertEqual(orderingCounter, expected) + orderingCounter = expected + 1 + loop.execute(delayedExecute) + } + + secondScheduled = loop.scheduleTask(in: .nanoseconds(5)) { + secondScheduled = nil + + let expected: Int + if let partner = firstScheduled { + // Ok, this callback fired first. Cancel the other, then set the expected current + // counter value to 1. + partner.cancel() + expected = 1 + } else { + // This callback fired second. + expected = 2 + } + + XCTAssertEqual(orderingCounter, expected) + orderingCounter = expected + 1 + loop.execute(delayedExecute) + } + + // Ok, now we set one more task to execute. + loop.execute { + XCTAssertEqual(orderingCounter, 0) + orderingCounter = 1 + } + + // Finally schedule a task for 10ns. + loop.scheduleTask(in: .nanoseconds(10)) { + XCTAssertEqual(orderingCounter, 5) + orderingCounter = 6 + } + + // Now we advance time by 10ns. + loop.advanceTime(by: .nanoseconds(10)) + + // Now the final value should be 6. + XCTAssertEqual(orderingCounter, 6) + } }