hashiverse_lib/transport/
mem_transport.rs1use crate::tools::types::Id;
16use crate::transport::ddos::ddos::{DdosConnectionGuard, DdosProtection};
17use crate::transport::transport::{IncomingRequest, ServerState, TransportFactory, TransportServer};
18use anyhow::{Result, anyhow};
19use bytes::Bytes;
20use log::info;
21use parking_lot::RwLock;
22use std::collections::HashMap;
23use std::sync::Arc;
24use tokio::sync::{mpsc, oneshot};
25use tokio_util::sync::CancellationToken;
26use tokio_util::task::TaskTracker;
27use crate::transport::bootstrap_provider::bootstrap_provider::BootstrapProvider;
28use crate::transport::bootstrap_provider::manual_bootstrap_provider::ManualBootstrapProvider;
29use crate::transport::ddos::noop_ddos::NoopDdosProtection;
30
31#[derive(Debug)]
32struct RpcMessage {
33 caller_address: String,
34 bytes: Bytes,
35 response_tx: oneshot::Sender<Result<Bytes>>,
36}
37
38struct ServerEntry {
39 command_tx: mpsc::Sender<RpcMessage>,
40}
41
42struct ServerManager {
43 servers: Arc<RwLock<HashMap<u16, Arc<ServerEntry>>>>,
44}
45
46impl ServerManager {
47 pub fn new() -> Self {
48 ServerManager {
49 servers: Arc::new(RwLock::new(HashMap::new())),
50 }
51 }
52 pub async fn remove_server(&self, port: u16) {
53 let mut servers_locked = self.servers.write();
54 servers_locked.remove(&port);
55 }
56}
57
58pub struct MemTransportServer {
69 port: u16,
70 address: String,
71 server_manager: Arc<ServerManager>,
72 command_rx: Arc<RwLock<Option<mpsc::Receiver<RpcMessage>>>>,
73 state: Arc<RwLock<ServerState>>,
74 ddos_protection: Arc<dyn DdosProtection>,
75}
76
77#[async_trait::async_trait]
78impl TransportServer for MemTransportServer {
79 fn get_address(&self) -> &String {
80 &self.address
81 }
82
83 async fn listen(&self, cancellation_token: CancellationToken, handler: mpsc::Sender<IncomingRequest>) -> Result<()> {
84 async fn process_connection(_cancellation_token: CancellationToken, handler: mpsc::Sender<IncomingRequest>, message: RpcMessage, ddos_protection: Arc<dyn DdosProtection>) -> anyhow::Result<()> {
85 let ddos_connection_guard = match DdosConnectionGuard::try_new(ddos_protection, message.caller_address.as_str()) {
89 Some(guard) => Arc::new(guard),
90 None => return Ok(()),
91 };
92 let caller_address = ddos_connection_guard.ip().to_string();
93 let (reply_tx, reply_rx) = oneshot::channel();
94 handler.send(IncomingRequest::new(caller_address, message.bytes, reply_tx, ddos_connection_guard)).await?;
95 let response = reply_rx.await?;
96 let _ = message.response_tx.send(Ok(response.to_bytes()));
97
98 Ok(())
99 }
100
101 {
103 let mut state = self.state.write();
104 match *state {
105 ServerState::Listening => {
106 anyhow::bail!("server is already listening");
107 }
108 ServerState::Shutdown => {
109 anyhow::bail!("server has been shut down");
110 }
111 ServerState::Created => {
112 *state = ServerState::Listening;
113 }
114 }
115 }
116
117 let task_tracker = TaskTracker::new();
118
119 info!("listening on address {}", self.address);
120
121 let mut receiver = match self.command_rx.write().take() {
123 Some(r) => r,
124 None => {
125 return Err(anyhow!("no receiver available on address {}", self.address));
126 }
127 };
128
129 loop {
130 tokio::select! {
131 _ = cancellation_token.cancelled() => {
132 break;
133 }
134
135 Some(msg) = receiver.recv() => {
136 task_tracker.spawn(
137 process_connection(cancellation_token.clone(), handler.clone(), msg, self.ddos_protection.clone())
138 );
139 }
140 }
141 }
142
143 info!("stopped listening on port {}", self.address);
144 self.server_manager.remove_server(self.port).await;
145
146 info!("waiting for open connections to complete");
148 task_tracker.close();
149 task_tracker.wait().await;
150
151 info!("all open connections complete");
153 *self.state.write() = ServerState::Shutdown;
154
155 Ok(())
156 }
157}
158
159#[derive(Clone)]
160pub struct MemTransportFactory {
161 server_manager: Arc<ServerManager>,
162 ddos_protection: Arc<dyn DdosProtection>,
163 bootstrap_provider: Arc<dyn BootstrapProvider>,
164}
165
166impl MemTransportFactory {
167 pub fn new(ddos_protection: Arc<dyn DdosProtection>, bootstrap_provider: Arc<dyn BootstrapProvider>) -> Self {
168 Self {
169 server_manager: Arc::new(ServerManager::new()),
170 ddos_protection,
171 bootstrap_provider,
172 }
173 }
174
175 pub fn default() -> Arc<Self> {
176 Arc::new(Self::new(NoopDdosProtection::default(), ManualBootstrapProvider::new_mem_multiple()))
177 }
178}
179
180#[async_trait::async_trait]
181impl TransportFactory for MemTransportFactory {
182 async fn get_bootstrap_addresses(&self) -> Vec<String> {
183 self.bootstrap_provider.get_bootstrap_addresses().await
184 }
185
186 async fn create_server(&self, _base_path: &str, port: u16, force_local_network: bool) -> anyhow::Result<Arc<dyn TransportServer>> {
187 if !force_local_network {
188 return Err(anyhow!("only local network is supported"));
189 }
190
191 let mut servers_locked = self.server_manager.servers.write();
192
193 if servers_locked.contains_key(&port) {
194 return Err(anyhow!("server already exists on port {}", port));
195 }
196
197 let bound_port = match port {
199 0 => {
200 servers_locked.keys().max().unwrap_or(&0u16) + 1
201 }
202 _ => port
203 };
204
205 let address = format!("{}", bound_port);
206
207 let (tx, rx) = mpsc::channel::<RpcMessage>(256);
212
213 let mem_transport_server = Arc::new(MemTransportServer {
215 port: bound_port,
216 address,
217 server_manager: self.server_manager.clone(),
218 command_rx: Arc::new(RwLock::new(Some(rx))),
219 state: Arc::new(RwLock::new(ServerState::Created)),
220 ddos_protection: self.ddos_protection.clone(),
221 });
222
223 servers_locked.insert(bound_port, Arc::new(ServerEntry { command_tx: tx }));
225
226 Ok(mem_transport_server)
227 }
228
229 async fn rpc(&self, address: &str, bytes: Bytes) -> Result<Bytes> {
230 let port: u16 = address.parse()?;
231
232 let server_entry = {
233 let servers = self.server_manager.servers.read();
234 let server_entry = servers.get(&port).ok_or_else(|| anyhow::anyhow!("no server found with port {}", port))?;
235 server_entry.clone()
236 };
237
238 let (response_tx, response_rx) = oneshot::channel();
243
244 let message = RpcMessage { caller_address: format!("mem:{}", Id::random()), bytes, response_tx };
246
247 server_entry.command_tx.send(message).await.map_err(|e| anyhow::anyhow!("failed to send request: {}", e))?;
251
252 response_rx.await.map_err(|_| anyhow::anyhow!("server disconnected before responding"))?
254 }
255}
256
257
#[cfg(test)]
mod tests {
    use crate::transport::bootstrap_provider::manual_bootstrap_provider::ManualBootstrapProvider;
    use crate::transport::ddos::noop_ddos::NoopDdosProtection;
    use crate::transport::mem_transport::MemTransportFactory;
    use crate::transport::transport::TransportFactory;
    use std::sync::Arc;

    /// Builds an isolated in-memory factory: no DDoS protection, default bootstrap provider.
    fn new_factory() -> Arc<dyn TransportFactory> {
        let factory = MemTransportFactory::new(NoopDdosProtection::default(), ManualBootstrapProvider::default());
        Arc::new(factory)
    }

    #[tokio::test]
    async fn rpc_test() -> anyhow::Result<()> {
        crate::transport::transport::tests::rpc_test(new_factory()).await
    }

    #[tokio::test]
    async fn bind_port_zero_test() -> anyhow::Result<()> {
        crate::transport::transport::tests::bind_port_zero_test(new_factory()).await
    }
}