hashiverse_server_lib/transport/
tcp_transport.rs1use crate::tools::tools::get_public_ipv4;
15use anyhow::anyhow;
16use bytes::Bytes;
17use futures::{SinkExt, StreamExt};
18use hashiverse_lib::tools::config;
19use hashiverse_lib::transport::ddos::ddos::{DdosConnectionGuard, DdosProtection};
20use hashiverse_lib::transport::transport::{IncomingRequest, ServerState, TransportFactory, TransportServer};
21use log::{info, trace, warn};
22use parking_lot::RwLock;
23use std::net::SocketAddr;
24use std::sync::Arc;
25use std::time::Duration;
26use tokio::net::{TcpListener, TcpStream};
27use tokio::sync::{Mutex, mpsc, oneshot};
28use tokio::time::sleep;
29use tokio_util::codec::{Framed, LengthDelimitedCodec};
30use tokio_util::sync::CancellationToken;
31use tokio_util::task::TaskTracker;
32use hashiverse_lib::transport::bootstrap_provider::bootstrap_provider::BootstrapProvider;
33
34#[derive(Clone)]
35pub struct TcpTransportFactory {
36 ddos_protection: Arc<dyn DdosProtection>,
37 bootstrap_provider: Arc<dyn BootstrapProvider>,
38}
39
40impl TcpTransportFactory {
41 pub fn new(ddos_protection: Arc<dyn DdosProtection>, bootstrap_provider: Arc<dyn BootstrapProvider>) -> Self {
42 Self { ddos_protection, bootstrap_provider }
43 }
44}
45
46pub struct TcpTransportServer {
47 address: String,
48 listener: Arc<Mutex<TcpListener>>,
49 state: Arc<RwLock<ServerState>>,
50 ddos_protection: Arc<dyn DdosProtection>,
51}
52
53impl TcpTransportServer {
54 async fn new(address: String, listener: TcpListener, ddos_protection: Arc<dyn DdosProtection>) -> anyhow::Result<Self> {
55 Ok(TcpTransportServer {
56 address,
57 listener: Arc::new(Mutex::new(listener)),
58 state: Arc::new(RwLock::new(ServerState::Created)),
59 ddos_protection,
60 })
61 }
62}
63
64#[async_trait::async_trait]
65impl TransportServer for TcpTransportServer {
66 fn get_address(&self) -> &String {
67 &self.address
68 }
69
70 async fn listen(&self, cancellation_token: CancellationToken, handler: mpsc::Sender<IncomingRequest>) -> anyhow::Result<()> {
71 {
73 let mut state = self.state.write();
74 match *state {
75 ServerState::Listening => {
76 anyhow::bail!("server is already listening");
77 }
78 ServerState::Shutdown => {
79 anyhow::bail!("server has been shut down");
80 }
81 ServerState::Created => {
82 *state = ServerState::Listening;
83 }
84 }
85 }
86
87 async fn process_connection(cancellation_token: CancellationToken, handler: mpsc::Sender<IncomingRequest>, socket: TcpStream, socket_addr: SocketAddr, ddos_protection: Arc<dyn DdosProtection>) -> anyhow::Result<()> {
88 let ip = socket_addr.ip().to_string();
92 let ddos_connection_guard = match DdosConnectionGuard::try_new(ddos_protection, &ip) {
93 Some(guard) => Arc::new(guard),
94 None => {
95 trace!("DDoS: dropping TCP connection from {}", ip);
96 return Ok(());
97 }
98 };
99 let caller_address = ddos_connection_guard.ip().to_string();
100 let mut framed = LengthDelimitedCodec::builder().max_frame_length(config::PROTOCOL_MAX_BLOB_SIZE_REQUEST).new_framed(socket);
101
102 let result = tokio::select! {
103 _ = cancellation_token.cancelled() => { return Err(anyhow!("cancelled")) },
104
105 _ = sleep(Duration::from_secs(2)) => {
106 Err(anyhow::anyhow!("timeout waiting for request"))
107 },
108
109 next = framed.next() => {
110 match next {
111 None => Ok(()),
112 Some(Ok(bytes)) => {
113 let (reply_tx, reply_rx) = oneshot::channel();
115 handler.send(IncomingRequest::new(caller_address, bytes.into(), reply_tx, ddos_connection_guard)).await?;
116 let response = reply_rx.await?;
117 framed.send(response.to_bytes()).await?;
118 Ok(())
119 },
120 Some(Err(e)) => Err(anyhow!("error reading string from framed stream: {}", e)),
121 }
122 }
123 };
124
125 if let Err(e) = result {
126 warn!("error processing connection: {}", e);
127 }
128
129 Ok(())
130 }
131
132 let task_tracker = TaskTracker::new();
133
134 info!("listening on address {}", self.address);
135
136 loop {
137 let listener = self.listener.lock().await;
138
139 tokio::select! {
140 _ = cancellation_token.cancelled() => {
141 break;
142 },
143 Ok((socket, socket_addr)) = listener.accept() => {
144 task_tracker.spawn(
145 process_connection(cancellation_token.clone(), handler.clone(), socket, socket_addr, self.ddos_protection.clone())
146 );
147 },
148 }
149 }
150
151 info!("stopped listening on address {}", self.address);
153 drop(self.listener.lock().await);
154
155 info!("waiting for open connections to complete");
157 task_tracker.close();
158 task_tracker.wait().await;
159
160 info!("all open connections complete");
162 *self.state.write() = ServerState::Shutdown;
163
164 Ok(())
165 }
166}
167
168#[async_trait::async_trait]
169impl TransportFactory for TcpTransportFactory {
170 async fn get_bootstrap_addresses(&self) -> Vec<String> {
171 self.bootstrap_provider.get_bootstrap_addresses().await
172 }
173
174 async fn create_server(&self, _base_path: &str, port: u16, force_local_network: bool) -> anyhow::Result<Arc<dyn TransportServer>> {
175 let address_to_bind = format!("0.0.0.0:{}", port);
177 info!("bind on: {}", address_to_bind);
178 let listener = TcpListener::bind(address_to_bind).await?;
179
180 let address_bound_ip = get_public_ipv4(force_local_network).await?;
181 let address_bound_port = listener.local_addr()?.port();
182 let address = format!("{}:{}", address_bound_ip, address_bound_port);
183
184 let tcp_transport_server = Arc::new(TcpTransportServer::new(address, listener, self.ddos_protection.clone()).await?);
185 Ok(tcp_transport_server)
186 }
187
188 async fn rpc(&self, address: &str, bytes: Bytes) -> anyhow::Result<Bytes> {
189 let stream = TcpStream::connect(address).await?;
190 let mut framed: Framed<TcpStream, LengthDelimitedCodec> = Framed::new(stream, LengthDelimitedCodec::new());
194 framed.send(bytes).await?;
195
196 trace!("awaiting response");
198 tokio::select! {
199 _ = sleep(Duration::from_secs(2)) => {
200 trace!("timeout");
201 Err(anyhow::anyhow!("timeout waiting for response"))
202 },
203
204 next_frame = framed.next() => {
205 match next_frame {
206 Some(Ok(bytes)) => {
207 Ok(bytes.into())
208 }
209 Some(Err(e)) => {
210 Err(anyhow::anyhow!("error reading response: {}", e)) },
211 None => {
212 Err(anyhow::anyhow!("no response")) },
213 }
214 }
215 }
216 }
217}
218
219
#[cfg(test)]
mod tests {
    use super::TcpTransportFactory;
    use hashiverse_lib::transport::bootstrap_provider::manual_bootstrap_provider::ManualBootstrapProvider;
    use hashiverse_lib::transport::ddos::noop_ddos::NoopDdosProtection;
    use hashiverse_lib::transport::transport::TransportFactory;
    use std::sync::Arc;

    /// Builds the TCP factory under test with no-op DDoS protection and a manual bootstrap provider.
    fn tcp_factory() -> Arc<dyn TransportFactory> {
        let factory = TcpTransportFactory::new(NoopDdosProtection::default(), ManualBootstrapProvider::default());
        Arc::new(factory)
    }

    /// Runs the shared transport RPC test against the TCP transport.
    #[tokio::test]
    async fn rpc_test() -> anyhow::Result<()> {
        hashiverse_lib::transport::transport::tests::rpc_test(tcp_factory()).await
    }

    /// Runs the shared bind-to-port-zero test against the TCP transport.
    #[tokio::test]
    async fn bind_port_zero_test() -> anyhow::Result<()> {
        hashiverse_lib::transport::transport::tests::bind_port_zero_test(tcp_factory()).await
    }
}