Skip to main content

hashiverse_lib/tools/time_provider/
manual_time_provider.rs

//! # Virtual clock for deterministic tests
//!
//! Implements [`crate::tools::time_provider::time_provider::TimeProvider`] as a
//! manually-advanced clock. Tests choose the starting time explicitly, and a driver
//! jumps the clock straight to the next pending wake time, so no `sleep` ever waits
//! in wall-clock time — a simulated network that covers hours or days of activity
//! runs in a few hundred milliseconds of wall time, deterministically.

8use crate::tools::time::{TimeMillis};
9use parking_lot::RwLock;
10use std::cmp::Ordering;
11use std::collections::BinaryHeap;
12use std::future::Future;
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::{Context, Poll, Waker};
16use std::time::Duration;
17use tokio::sync::Notify;
18use tokio_util::sync::CancellationToken;
19use crate::tools::time_provider::time_provider::TimeProvider;
20
21/// A wake time entry for the manual time provider
22#[derive(Debug)]
23struct ManualTimeProviderWakeTime {
24    time: TimeMillis,
25    waker: Waker,
26}
27
28impl PartialEq for ManualTimeProviderWakeTime {
29    fn eq(&self, other: &Self) -> bool {
30        self.time == other.time
31    }
32}
33
34impl Eq for ManualTimeProviderWakeTime {}
35
36impl PartialOrd for ManualTimeProviderWakeTime {
37    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
38        Some(self.cmp(other))
39    }
40}
41
42impl Ord for ManualTimeProviderWakeTime {
43    fn cmp(&self, other: &Self) -> Ordering {
44        // We want a min-heap, so reverse the comparison
45        other.time.cmp(&self.time)
46    }
47}
48
49/// Implementation of TimeProvider that allows manual control of time
50///
51/// This provider allows you to explicitly set the current time for testing
52/// time-dependent code without relying on the system clock.
53#[derive(Clone)]
54pub struct ManualTimeProvider {
55    current_time: Arc<RwLock<TimeMillis>>,
56    wake_times: Arc<RwLock<BinaryHeap<ManualTimeProviderWakeTime>>>,
57    new_sleepers_notify: Arc<Notify>,
58}
59
60impl Default for ManualTimeProvider {
61    fn default() -> Self {
62        Self::new(TimeMillis::zero())
63    }
64}
65
66impl ManualTimeProvider {
67    /// Create a new ManualTimeProvider with a specified starting time
68    pub fn new(start_time_millis: TimeMillis) -> Self {
69        Self {
70            current_time: Arc::new(RwLock::new(start_time_millis)),
71            wake_times: Arc::new(RwLock::new(BinaryHeap::new())),
72            new_sleepers_notify: Arc::new(Notify::new()),
73        }
74    }
75
76    /// Advance time until there are no more registered sleepers.
77    ///
78    /// It jumps time forward to the next wakeup repeatedly, waking all due tasks,
79    /// and yields to the executor so those tasks can make progress and potentially
80    /// register more sleeps.
81    pub async fn run_all_sleepers_till_done(&self, cancellation_token: &CancellationToken) {
82        while !cancellation_token.is_cancelled() {
83            if self.wake_times.read().is_empty() {
84                tokio::select! {
85                   _ = self.new_sleepers_notify.notified() => {},
86                    _ = cancellation_token.cancelled() => {},
87                }
88            }
89
90            tokio::task::yield_now().await;
91            self.advance_time_until_next_sleeper().await;
92        }
93    }
94
95    /// Advance time to the earlier of: the next scheduled wake time or the current time plus max_advance_ms
96    ///
97    /// Returns the amount of time (in milliseconds) that was actually advanced and the remaining number of sleeping taks
98    pub async fn advance_time_until_next_sleeper(&self) {
99        let mut current = self.current_time.write();
100        let mut wake_times = self.wake_times.write();
101
102        let new_time = match wake_times.peek() {
103            Some(wake_time) => wake_time.time,
104            None => *current,
105        };
106
107        // Set the new current time
108        *current = new_time;
109
110        // Collect wakers that need to be awakened
111        let mut wakers_to_wake = Vec::new();
112
113        while let Some(wake_time) = wake_times.peek() {
114            if wake_time.time <= new_time {
115                if let Some(entry) = wake_times.pop() {
116                    wakers_to_wake.push(entry.waker);
117                }
118            }
119            else {
120                break;
121            }
122        }
123
124        // Drop the locks before waking to avoid potential deadlocks
125        drop(current);
126        drop(wake_times);
127
128        // Now wake all the sleepers
129        for waker in wakers_to_wake {
130            waker.wake();
131        }
132    }
133
134    /// Register a waker to be notified when time reaches a certain point
135    fn register_wake_time(&self, wake_time: TimeMillis, waker: Waker) {
136        let mut wake_times = self.wake_times.write();
137        wake_times.push(ManualTimeProviderWakeTime { time: wake_time, waker });
138
139        // Wake the time driver (if it is waiting for more sleepers)
140        self.new_sleepers_notify.notify_one();
141    }
142}
143
144/// Future returned by ManualTimeProvider::sleep
145pub struct ManualTimeProviderSleep {
146    provider: ManualTimeProvider,
147    wake_time: TimeMillis,
148}
149
150impl Future for ManualTimeProviderSleep {
151    type Output = ();
152
153    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
154        // Check if the current time has reached or passed the wake time
155        let current_time = self.provider.current_time_millis();
156        if current_time >= self.wake_time {
157            return Poll::Ready(());
158        }
159
160        // Register waker so we're notified when time advances.
161        let waker = cx.waker().clone();
162        self.provider.register_wake_time(self.wake_time, waker);
163
164        Poll::Pending
165    }
166}
167
168impl TimeProvider for ManualTimeProvider {
169    fn current_time_millis(&self) -> TimeMillis {
170        *self.current_time.read()
171    }
172
173    fn sleep(&self, duration: Duration) -> Pin<Box<dyn Future<Output = ()> + Send>> {
174        let current_time = self.current_time_millis();
175        let wake_time = current_time + duration;
176
177        Box::pin(ManualTimeProviderSleep {
178            provider: self.clone(), // Note that the provider may be a different copy, but the state inside is shared...
179            wake_time,
180        })
181    }
182}
183
#[cfg(test)]
mod tests {
    use crate::tools::time::{MILLIS_IN_SECOND, TimeMillis};
    use crate::tools::time_provider::manual_time_provider::ManualTimeProvider;
    use crate::tools::time_provider::time_provider::TimeProvider;
    use log::info;
    use std::sync::Arc;
    use tokio_util::sync::CancellationToken;

    /// Yield to the executor `times` times in a row.
    async fn yield_n(times: usize) {
        for _ in 0..times {
            tokio::task::yield_now().await;
        }
    }

    /// Smoke test with no assertions: it passes when the `join!` below completes.
    ///
    /// Two workers each perform ten one-second virtual sleeps while
    /// `run_all_sleepers_till_done` drives the clock. The second worker also yields
    /// to the executor around every tick, and cancels the driver once it is done.
    #[tokio::test]
    async fn generic_test() {
        let time_provider = Arc::new(ManualTimeProvider::new(TimeMillis::zero()));
        // To watch the interleaving, route logging through this clock, e.g.:
        // configure_logging_with_time_provider("trace", time_provider.clone());

        let cancellation_token = CancellationToken::new();

        let worker_one = async {
            info!("Thread 1 start");
            for _ in 0..10 {
                info!("Thread 1 tick");
                time_provider.sleep_millis(MILLIS_IN_SECOND.const_mul(1)).await;
            }
            info!("Thread 1 end");
        };

        let worker_two = async {
            info!("Thread 2 start");
            for _ in 0..10 {
                yield_n(3).await;
                info!("Thread 2 tick");
                yield_n(3).await;
                time_provider.sleep_millis(MILLIS_IN_SECOND.const_mul(1)).await;
                yield_n(3).await;
            }
            cancellation_token.cancel();
            info!("Thread 2 end");
        };

        let time_driver = async {
            info!("Time driver start");
            time_provider.run_all_sleepers_till_done(&cancellation_token).await;
            info!("Time driver end");
        };

        tokio::join!(worker_one, worker_two, time_driver);
    }
}