Skip to main content

hashiverse_server_lib/transport/
tcp_transport.rs

1//! # Plain-text TCP transport
2//!
3//! An unencrypted transport for local testing and private-LAN deployments where TLS
4//! is unnecessary. Frames requests and responses with `tokio-util`'s
5//! `LengthDelimitedCodec` — each message is prefixed with a u32 length, so there's
6//! no application-level ambiguity about where one message ends and the next begins.
7//!
8//! Uses the same pluggable
9//! [`hashiverse_lib::transport::ddos::ddos::DdosProtection`] trait as the HTTPS
10//! transport, so `NoopDdosProtection`, `MemDdos`, or the ipset-backed protection can
11//! all drop in unchanged. Per-request timeout is 2 seconds; anything slower is
12//! considered either a buggy client or a slow-loris probe.
13
14use crate::tools::tools::get_public_ipv4;
15use anyhow::anyhow;
16use bytes::Bytes;
17use futures::{SinkExt, StreamExt};
18use hashiverse_lib::tools::config;
19use hashiverse_lib::transport::ddos::ddos::{DdosConnectionGuard, DdosProtection};
20use hashiverse_lib::transport::transport::{IncomingRequest, ServerState, TransportFactory, TransportServer};
21use log::{info, trace, warn};
22use parking_lot::RwLock;
23use std::net::SocketAddr;
24use std::sync::Arc;
25use std::time::Duration;
26use tokio::net::{TcpListener, TcpStream};
27use tokio::sync::{Mutex, mpsc, oneshot};
28use tokio::time::sleep;
29use tokio_util::codec::{Framed, LengthDelimitedCodec};
30use tokio_util::sync::CancellationToken;
31use tokio_util::task::TaskTracker;
32use hashiverse_lib::transport::bootstrap_provider::bootstrap_provider::BootstrapProvider;
33
34#[derive(Clone)]
35pub struct TcpTransportFactory {
36    ddos_protection: Arc<dyn DdosProtection>,
37    bootstrap_provider: Arc<dyn BootstrapProvider>,
38}
39
40impl TcpTransportFactory {
41    pub fn new(ddos_protection: Arc<dyn DdosProtection>, bootstrap_provider: Arc<dyn BootstrapProvider>) -> Self {
42        Self { ddos_protection, bootstrap_provider }
43    }
44}
45
46pub struct TcpTransportServer {
47    address: String,
48    listener: Arc<Mutex<TcpListener>>,
49    state: Arc<RwLock<ServerState>>,
50    ddos_protection: Arc<dyn DdosProtection>,
51}
52
53impl TcpTransportServer {
54    async fn new(address: String, listener: TcpListener, ddos_protection: Arc<dyn DdosProtection>) -> anyhow::Result<Self> {
55        Ok(TcpTransportServer {
56            address,
57            listener: Arc::new(Mutex::new(listener)),
58            state: Arc::new(RwLock::new(ServerState::Created)),
59            ddos_protection,
60        })
61    }
62}
63
64#[async_trait::async_trait]
65impl TransportServer for TcpTransportServer {
66    fn get_address(&self) -> &String {
67        &self.address
68    }
69
70    async fn listen(&self, cancellation_token: CancellationToken, handler: mpsc::Sender<IncomingRequest>) -> anyhow::Result<()> {
71        // Check that we can transition to listening
72        {
73            let mut state = self.state.write();
74            match *state {
75                ServerState::Listening => {
76                    anyhow::bail!("server is already listening");
77                }
78                ServerState::Shutdown => {
79                    anyhow::bail!("server has been shut down");
80                }
81                ServerState::Created => {
82                    *state = ServerState::Listening;
83                }
84            }
85        }
86
87        async fn process_connection(cancellation_token: CancellationToken, handler: mpsc::Sender<IncomingRequest>, socket: TcpStream, socket_addr: SocketAddr, ddos_protection: Arc<dyn DdosProtection>) -> anyhow::Result<()> {
88            // trace!("accepted connection on: {socket_addr}");
89            // defer! { trace!("dropped connection from: {socket_addr}"); }
90
91            let ip = socket_addr.ip().to_string();
92            let ddos_connection_guard = match DdosConnectionGuard::try_new(ddos_protection, &ip) {
93                Some(guard) => Arc::new(guard),
94                None => {
95                    trace!("DDoS: dropping TCP connection from {}", ip);
96                    return Ok(());
97                }
98            };
99            let caller_address = ddos_connection_guard.ip().to_string();
100            let mut framed = LengthDelimitedCodec::builder().max_frame_length(config::PROTOCOL_MAX_BLOB_SIZE_REQUEST).new_framed(socket);
101
102            let result = tokio::select! {
103                _ = cancellation_token.cancelled() => { return Err(anyhow!("cancelled")) },
104
105                _ = sleep(Duration::from_secs(2)) => {
106                    Err(anyhow::anyhow!("timeout waiting for request"))
107                },
108
109                next = framed.next() => {
110                    match next {
111                        None => Ok(()),
112                        Some(Ok(bytes)) => {
113                            // trace!("received bytes={:?}", bytes);
114                            let (reply_tx, reply_rx) = oneshot::channel();
115                            handler.send(IncomingRequest::new(caller_address, bytes.into(), reply_tx, ddos_connection_guard)).await?;
116                            let response = reply_rx.await?;
117                            framed.send(response.to_bytes()).await?;
118                            Ok(())
119                        },
120                        Some(Err(e)) => Err(anyhow!("error reading string from framed stream: {}", e)),
121                    }
122                }
123            };
124
125            if let Err(e) = result {
126                warn!("error processing connection: {}", e);
127            }
128
129            Ok(())
130        }
131
132        let task_tracker = TaskTracker::new();
133
134        info!("listening on address {}", self.address);
135
136        loop {
137            let listener = self.listener.lock().await;
138
139            tokio::select! {
140                _ = cancellation_token.cancelled() => {
141                    break;
142                },
143                Ok((socket, socket_addr)) = listener.accept() => {
144                    task_tracker.spawn(
145                        process_connection(cancellation_token.clone(), handler.clone(), socket, socket_addr, self.ddos_protection.clone())
146                    );
147                },
148            }
149        }
150
151        // Stop accepting new connections
152        info!("stopped listening on address {}", self.address);
153        drop(self.listener.lock().await);
154
155        // Wait for existing connections to complete
156        info!("waiting for open connections to complete");
157        task_tracker.close();
158        task_tracker.wait().await;
159
160        // Notify the "shutdown" coroutine that we have successfully shutdown
161        info!("all open connections complete");
162        *self.state.write() = ServerState::Shutdown;
163
164        Ok(())
165    }
166}
167
168#[async_trait::async_trait]
169impl TransportFactory for TcpTransportFactory {
170    async fn get_bootstrap_addresses(&self) -> Vec<String> {
171        self.bootstrap_provider.get_bootstrap_addresses().await
172    }
173
174    async fn create_server(&self, _base_path: &str, port: u16, force_local_network: bool) -> anyhow::Result<Arc<dyn TransportServer>> {
175        // Deliberately IPv4-only.  See https_transport.rs for the reasoning.
176        let address_to_bind = format!("0.0.0.0:{}", port);
177        info!("bind on: {}", address_to_bind);
178        let listener = TcpListener::bind(address_to_bind).await?;
179
180        let address_bound_ip = get_public_ipv4(force_local_network).await?;
181        let address_bound_port = listener.local_addr()?.port();
182        let address = format!("{}:{}", address_bound_ip, address_bound_port);
183
184        let tcp_transport_server = Arc::new(TcpTransportServer::new(address, listener, self.ddos_protection.clone()).await?);
185        Ok(tcp_transport_server)
186    }
187
188    async fn rpc(&self, address: &str, bytes: Bytes) -> anyhow::Result<Bytes> {
189        let stream = TcpStream::connect(address).await?;
190        // trace!("connected to: {}", address.address);
191        // defer! { trace!("disconnected from: {}", &address.address); }
192
193        let mut framed: Framed<TcpStream, LengthDelimitedCodec> = Framed::new(stream, LengthDelimitedCodec::new());
194        framed.send(bytes).await?;
195
196        // Return the response
197        trace!("awaiting response");
198        tokio::select! {
199            _ = sleep(Duration::from_secs(2)) => {
200                trace!("timeout");
201                Err(anyhow::anyhow!("timeout waiting for response"))
202            },
203
204            next_frame = framed.next() => {
205                match next_frame {
206                    Some(Ok(bytes)) => {
207                        Ok(bytes.into())
208                    }
209                    Some(Err(e)) => {
210                        Err(anyhow::anyhow!("error reading response: {}", e)) },
211                    None => {
212                        Err(anyhow::anyhow!("no response")) },
213                }
214           }
215        }
216    }
217}
218
219
#[cfg(test)]
mod tests {
    use crate::transport::tcp_transport::TcpTransportFactory;
    use hashiverse_lib::transport::bootstrap_provider::manual_bootstrap_provider::ManualBootstrapProvider;
    use hashiverse_lib::transport::ddos::noop_ddos::NoopDdosProtection;
    use hashiverse_lib::transport::transport::TransportFactory;
    use std::sync::Arc;

    /// Builds a TCP factory with no-op DDoS protection and a default manual bootstrap provider.
    fn new_factory() -> Arc<dyn TransportFactory> {
        Arc::new(TcpTransportFactory::new(NoopDdosProtection::default(), ManualBootstrapProvider::default()))
    }

    /// Runs the shared transport RPC round-trip test against the TCP transport.
    #[tokio::test]
    async fn rpc_test() -> anyhow::Result<()> {
        let factory = new_factory();
        hashiverse_lib::transport::transport::tests::rpc_test(factory).await
    }

    /// Runs the shared "bind to port 0" test against the TCP transport.
    #[tokio::test]
    async fn bind_port_zero_test() -> anyhow::Result<()> {
        let factory = new_factory();
        hashiverse_lib::transport::transport::tests::bind_port_zero_test(factory).await
    }
}