Skip to main content

hashiverse_lib/transport/
mem_transport.rs

1//! # In-memory transport for tests
2//!
3//! A fully synchronous in-process implementation of
4//! [`crate::transport::transport::TransportFactory`] and
5//! [`crate::transport::transport::TransportServer`]: every "server" registers itself in
6//! a shared registry keyed by id, and every "client" request is just a channel send into
7//! the matching server's request queue.
8//!
9//! This is what makes the integration-test harness fast and deterministic. A virtual
10//! network of dozens of servers + clients runs inside a single test binary, with no
11//! sockets, no TLS negotiation, no PoW relaxation fudge, and no flaky wall-clock
12//! ordering. Swap `MemTransportFactory` for the HTTPS factory and the same protocol code
13//! runs on the real network.
14
15use crate::tools::types::Id;
16use crate::transport::ddos::ddos::{DdosConnectionGuard, DdosProtection};
17use crate::transport::transport::{IncomingRequest, ServerState, TransportFactory, TransportServer};
18use anyhow::{Result, anyhow};
19use bytes::Bytes;
20use log::info;
21use parking_lot::RwLock;
22use std::collections::HashMap;
23use std::sync::Arc;
24use tokio::sync::{mpsc, oneshot};
25use tokio_util::sync::CancellationToken;
26use tokio_util::task::TaskTracker;
27use crate::transport::bootstrap_provider::bootstrap_provider::BootstrapProvider;
28use crate::transport::bootstrap_provider::manual_bootstrap_provider::ManualBootstrapProvider;
29use crate::transport::ddos::noop_ddos::NoopDdosProtection;
30
/// A single in-flight RPC travelling from a client into a server's request queue.
#[derive(Debug)]
struct RpcMessage {
    /// Synthetic caller address (e.g. `mem:<random-id>`), fed to DDoS accounting on the server side.
    caller_address: String,
    /// Raw request payload delivered to the server's handler.
    bytes: Bytes,
    /// Oneshot channel on which the server delivers the response (or an error) back to the caller.
    response_tx: oneshot::Sender<Result<Bytes>>,
}
37
/// Registry entry for one live server: the sending half of its request queue.
struct ServerEntry {
    /// Messages sent here are drained by the server's `listen` loop.
    command_tx: mpsc::Sender<RpcMessage>,
}
41
/// Process-wide registry mapping port numbers to live in-memory servers.
struct ServerManager {
    /// Shared behind `Arc` so each server can deregister itself on shutdown.
    servers: Arc<RwLock<HashMap<u16, Arc<ServerEntry>>>>,
}
45
46impl ServerManager {
47    pub fn new() -> Self {
48        ServerManager {
49            servers: Arc::new(RwLock::new(HashMap::new())),
50        }
51    }
52    pub async fn remove_server(&self, port: u16) {
53        let mut servers_locked = self.servers.write();
54        servers_locked.remove(&port);
55    }
56}
57
/// An entirely in-process [`TransportServer`] used by the integration test harness.
///
/// Servers created by `MemTransportFactory` share a process-wide registry keyed by port;
/// "sending a request" from one client to one server becomes a channel send on the registry.
/// There is no serialization to sockets, no DNS, no kernel — which makes this both
/// dramatically faster than a real network and fully deterministic when paired with a virtual
/// [`crate::tools::time_provider::time_provider::TimeProvider`]. Port `0` is translated to a
/// freshly-allocated port number, mirroring the semantics of a real OS bind.
///
/// Not for production use: there is nothing here that crosses a process or host boundary.
pub struct MemTransportServer {
    /// Port this server is registered under in the shared registry.
    port: u16,
    /// String form of the port; doubles as the transport "address".
    address: String,
    /// Registry handle, used to deregister on shutdown.
    server_manager: Arc<ServerManager>,
    /// Receiving half of the request queue; taken (leaving `None`) when `listen` starts.
    command_rx: Arc<RwLock<Option<mpsc::Receiver<RpcMessage>>>>,
    /// Created -> Listening -> Shutdown lifecycle state.
    state: Arc<RwLock<ServerState>>,
    /// Admission control applied to every incoming connection.
    ddos_protection: Arc<dyn DdosProtection>,
}
76
77#[async_trait::async_trait]
78impl TransportServer for MemTransportServer {
79    fn get_address(&self) -> &String {
80        &self.address
81    }
82
83    async fn listen(&self, cancellation_token: CancellationToken, handler: mpsc::Sender<IncomingRequest>) -> Result<()> {
84        async fn process_connection(_cancellation_token: CancellationToken, handler: mpsc::Sender<IncomingRequest>, message: RpcMessage, ddos_protection: Arc<dyn DdosProtection>) -> anyhow::Result<()> {
85            // trace!("accepted connection");
86            // scopeguard::defer! { trace!("dropped connection"); }
87            // trace!("received packet={:?}", message.bytes);
88            let ddos_connection_guard = match DdosConnectionGuard::try_new(ddos_protection, message.caller_address.as_str()) {
89                Some(guard) => Arc::new(guard),
90                None => return Ok(()),
91            };
92            let caller_address = ddos_connection_guard.ip().to_string();
93            let (reply_tx, reply_rx) = oneshot::channel();
94            handler.send(IncomingRequest::new(caller_address, message.bytes, reply_tx, ddos_connection_guard)).await?;
95            let response = reply_rx.await?;
96            let _ = message.response_tx.send(Ok(response.to_bytes()));
97
98            Ok(())
99        }
100
101        // Check that we can transition to listening
102        {
103            let mut state = self.state.write();
104            match *state {
105                ServerState::Listening => {
106                    anyhow::bail!("server is already listening");
107                }
108                ServerState::Shutdown => {
109                    anyhow::bail!("server has been shut down");
110                }
111                ServerState::Created => {
112                    *state = ServerState::Listening;
113                }
114            }
115        }
116
117        let task_tracker = TaskTracker::new();
118
119        info!("listening on address {}", self.address);
120
121        // Take ownership of the receiver.  If there's no receiver, we can't listen.  Should never happen!
122        let mut receiver = match self.command_rx.write().take() {
123            Some(r) => r,
124            None => {
125                return Err(anyhow!("no receiver available on address {}", self.address));
126            }
127        };
128
129        loop {
130            tokio::select! {
131                _ = cancellation_token.cancelled() => {
132                    break;
133                }
134
135                Some(msg) = receiver.recv() => {
136                    task_tracker.spawn(
137                        process_connection(cancellation_token.clone(), handler.clone(), msg, self.ddos_protection.clone())
138                    );
139                }
140            }
141        }
142
143        info!("stopped listening on port {}", self.address);
144        self.server_manager.remove_server(self.port).await;
145
146        // Wait for existing connections to complete
147        info!("waiting for open connections to complete");
148        task_tracker.close();
149        task_tracker.wait().await;
150
151        // Notify the "shutdown" coroutine that we have successfully shutdown
152        info!("all open connections complete");
153        *self.state.write() = ServerState::Shutdown;
154
155        Ok(())
156    }
157}
158
/// Factory producing [`MemTransportServer`]s that all share one in-process registry.
///
/// Cloning the factory clones the `Arc` handles, so a clone's clients can reach
/// servers created by the original.
#[derive(Clone)]
pub struct MemTransportFactory {
    /// Shared port -> server registry.
    server_manager: Arc<ServerManager>,
    /// DDoS protection handed to every server this factory creates.
    ddos_protection: Arc<dyn DdosProtection>,
    /// Source of the addresses returned by `get_bootstrap_addresses`.
    bootstrap_provider: Arc<dyn BootstrapProvider>,
}
165
166impl MemTransportFactory {
167    pub fn new(ddos_protection: Arc<dyn DdosProtection>, bootstrap_provider: Arc<dyn BootstrapProvider>) -> Self {
168        Self {
169            server_manager: Arc::new(ServerManager::new()),
170            ddos_protection,
171            bootstrap_provider,
172        }
173    }
174
175    pub fn default() -> Arc<Self> {
176        Arc::new(Self::new(NoopDdosProtection::default(), ManualBootstrapProvider::new_mem_multiple()))
177    }
178}
179
180#[async_trait::async_trait]
181impl TransportFactory for MemTransportFactory {
182    async fn get_bootstrap_addresses(&self) -> Vec<String> {
183        self.bootstrap_provider.get_bootstrap_addresses().await
184    }
185
186    async fn create_server(&self, _base_path: &str, port: u16, force_local_network: bool) -> anyhow::Result<Arc<dyn TransportServer>> {
187        if !force_local_network {
188            return Err(anyhow!("only local network is supported"));
189        }
190
191        let mut servers_locked = self.server_manager.servers.write();
192
193        if servers_locked.contains_key(&port) {
194            return Err(anyhow!("server already exists on port {}", port));
195        }
196
197        // If they have requested port 0, pick the first available empty slot
198        let bound_port = match port {
199            0 => {
200                servers_locked.keys().max().unwrap_or(&0u16) + 1
201            }
202            _ => port
203        };
204
205        let address = format!("{}", bound_port);
206
207        // Create channels for communication.  Buffer sized generously so bursts of
208        // concurrent in-memory RPCs don't trip capacity limits; backpressure is still
209        // applied via awaited `send` below, which is closer to the behaviour of a real
210        // TCP socket than `try_send`'s fail-fast.
211        let (tx, rx) = mpsc::channel::<RpcMessage>(256);
212
213        // Create the server
214        let mem_transport_server = Arc::new(MemTransportServer {
215            port: bound_port,
216            address,
217            server_manager: self.server_manager.clone(),
218            command_rx: Arc::new(RwLock::new(Some(rx))),
219            state: Arc::new(RwLock::new(ServerState::Created)),
220            ddos_protection: self.ddos_protection.clone(),
221        });
222
223        // Store the server and its sender in the map
224        servers_locked.insert(bound_port, Arc::new(ServerEntry { command_tx: tx }));
225
226        Ok(mem_transport_server)
227    }
228
229    async fn rpc(&self, address: &str, bytes: Bytes) -> Result<Bytes> {
230        let port: u16 = address.parse()?;
231
232        let server_entry = {
233            let servers = self.server_manager.servers.read();
234            let server_entry = servers.get(&port).ok_or_else(|| anyhow::anyhow!("no server found with port {}", port))?;
235            server_entry.clone()
236        };
237
238        // trace!("connected to: {:?}", address);
239        // defer! { trace!("disconnected from: {:?}", &address); }
240
241        // Create a oneshot channel for the response
242        let (response_tx, response_rx) = oneshot::channel();
243
244        // Create the message
245        let message = RpcMessage { caller_address: format!("mem:{}", Id::random()), bytes, response_tx };
246
247        // Send the message to the server using the sender from the server entry.
248        // Awaited `send` applies backpressure if the receiver is saturated, rather than
249        // dropping the request — mirrors how a real TCP transport would behave.
250        server_entry.command_tx.send(message).await.map_err(|e| anyhow::anyhow!("failed to send request: {}", e))?;
251
252        // Wait for the response
253        response_rx.await.map_err(|_| anyhow::anyhow!("server disconnected before responding"))?
254    }
255}
256
257
#[cfg(test)]
mod tests {
    use crate::transport::bootstrap_provider::manual_bootstrap_provider::ManualBootstrapProvider;
    use crate::transport::ddos::noop_ddos::NoopDdosProtection;
    use crate::transport::mem_transport::MemTransportFactory;
    use crate::transport::transport::{tests as transport_tests, TransportFactory};
    use std::sync::Arc;

    /// Builds a factory wired with no-op DDoS protection and a default bootstrap provider.
    fn make_factory() -> Arc<dyn TransportFactory> {
        Arc::new(MemTransportFactory::new(NoopDdosProtection::default(), ManualBootstrapProvider::default()))
    }

    #[tokio::test]
    async fn rpc_test() -> anyhow::Result<()> {
        transport_tests::rpc_test(make_factory()).await
    }

    #[tokio::test]
    async fn bind_port_zero_test() -> anyhow::Result<()> {
        transport_tests::bind_port_zero_test(make_factory()).await
    }
}