hashiverse_server_lib/transport/
full_https_transport.rs1use crate::transport::https_transport_cert_refresher::HttpsTransportCertRefresher;
24use crate::tools::tools::get_public_ipv4;
25use anyhow::anyhow;
26use axum::body::Body;
27use axum::extract::{DefaultBodyLimit, Extension};
28use axum::http::{header, StatusCode, Uri};
29use axum::response::{IntoResponse, Response};
30use axum::{routing::get, Router};
31use bytes::Bytes;
32use futures::stream;
33use hashiverse_lib::tools::config;
34use hashiverse_lib::transport::ddos::ddos::{DdosConnectionGuard, DdosProtection};
35use hashiverse_lib::transport::transport::{IncomingRequest, ServerState, TransportFactory, TransportServer};
36use hyper::body::Incoming;
37use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
38use hyper_util::server::conn::auto::Builder as AutoBuilder;
39use log::{error, info, trace, warn};
40use parking_lot::RwLock;
41use rustls::ServerConfig;
42use std::convert::Infallible;
43use std::path::PathBuf;
44use std::sync::Arc;
45use std::time::Duration;
46use tokio::net::TcpListener;
47use tokio::sync::{mpsc, oneshot, Mutex, Semaphore};
48use tokio::task::JoinSet;
49use tokio_rustls::TlsAcceptor;
50use tokio_util::sync::CancellationToken;
51use tower::{Service, ServiceExt};
52use tower_http::cors::CorsLayer;
53use tower_http::timeout::RequestBodyTimeoutLayer;
54use hashiverse_lib::transport::bootstrap_provider::bootstrap_provider::BootstrapProvider;
55
56#[derive(Clone)]
63pub struct FullHttpsTransportFactory {
64 ddos_protection: Arc<dyn DdosProtection>,
65 https_transport_factory: hashiverse_lib::transport::partial_https_transport::PartialHttpsTransportFactory,
66}
67
68pub struct FullHttpsTransportServer {
69 base_path: String,
70 force_local_network: bool,
71 address: String,
72 ip: String,
73 port: u16,
74 listener: Arc<Mutex<Option<TcpListener>>>, state: Arc<RwLock<ServerState>>,
76 ddos_protection: Arc<dyn DdosProtection>,
77}
78
79impl FullHttpsTransportServer {
80 async fn new(base_path: &str, address: String, ip: String, port: u16, force_local_network: bool, listener: TcpListener, ddos_protection: Arc<dyn DdosProtection>) -> anyhow::Result<Self> {
81 Ok(FullHttpsTransportServer {
82 base_path: base_path.to_string(),
83 force_local_network,
84 address,
85 ip,
86 port,
87 listener: Arc::new(Mutex::new(Some(listener))),
88 state: Arc::new(RwLock::new(ServerState::Created)),
89 ddos_protection
90 })
91 }
92}
93
94#[async_trait::async_trait]
95impl TransportServer for FullHttpsTransportServer {
96 fn get_address(&self) -> &String {
97 &self.address
98 }
99
100 async fn listen(&self, cancellation_token: CancellationToken, handler: mpsc::Sender<IncomingRequest>) -> anyhow::Result<()> {
101 {
103 let mut state = self.state.write();
104 match *state {
105 ServerState::Listening => {
106 anyhow::bail!("server is already listening");
107 }
108 ServerState::Shutdown => {
109 anyhow::bail!("server has been shut down");
110 }
111 ServerState::Created => {
112 *state = ServerState::Listening;
113 }
114 }
115 }
116
117 info!("listening on address {}", self.address);
118
119 let mut listener = self.listener.lock().await;
120 let listener = match listener.take() {
121 Some(listener) => listener,
122 None => {
123 return Err(anyhow!("listener had already been taken"));
124 }
125 };
126
127 let handler_clone = handler.clone();
130 let handle_blob = move |Extension(ddos_connection_guard): Extension<Arc<DdosConnectionGuard>>, bytes: Bytes| async move {
131 let handler = handler_clone.clone();
132
133 if !ddos_connection_guard.allow_request() {
134 trace!("DDoS: request from {} blocked", ddos_connection_guard.ip());
135 return Err(StatusCode::TOO_MANY_REQUESTS);
136 }
137
138 let caller_address = ddos_connection_guard.ip().to_string();
139
140 let result: anyhow::Result<Response<axum::body::Body>> = try {
141 let (reply_tx, reply_rx) = oneshot::channel();
142 handler.send(IncomingRequest::new(caller_address, bytes, reply_tx, ddos_connection_guard.clone())).await.map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?;
143 let response = reply_rx.await.map_err(|e| anyhow::anyhow!("Failed to receive message: {}", e))?;
144
145 let content_length = response.len();
147 let segments = response.compact(config::TRANSPORT_BYTES_GATHERER_COMPACT_THRESHOLD).finish();
148 let body = axum::body::Body::from_stream(stream::iter(segments.into_iter().map(Ok::<Bytes, Infallible>)));
149
150 let response = axum::http::Response::builder()
151 .status(StatusCode::OK)
152 .header(header::CONTENT_TYPE, "application/octet-stream")
153 .header(header::CONTENT_LENGTH, content_length)
154 .body(body)
155 .map_err(|e| anyhow::anyhow!("Failed to build response: {}", e))?;
156
157 response
158 };
159
160 match result {
161 Ok(response) => Ok(response.into_response()),
162
163 Err(e) => {
164 warn!("error processing blob: {}", e);
165 ddos_connection_guard.report_bad_request();
166 Err(StatusCode::BAD_REQUEST)
167 }
168 }
169 };
170
171 let fallback_handler = move |Extension(ddos_connection_guard): Extension<Arc<DdosConnectionGuard>>, uri: Uri| {
172 async move {
173 trace!("unhandled route for path: {} from {}", uri, ddos_connection_guard.ip());
174 ddos_connection_guard.report_bad_request();
175 StatusCode::NOT_FOUND
176 }
177 };
178
179 let axum_app = Router::new()
180 .route("/", get(|| async { "Hashiverse!" }).post(handle_blob))
181 .layer(DefaultBodyLimit::max(config::PROTOCOL_MAX_BLOB_SIZE_REQUEST))
182 .layer(RequestBodyTimeoutLayer::new(Duration::from_secs(config::HTTPS_SERVER_TRANSPORT_BODY_READ_TIMEOUT_SECS)))
183 .layer(CorsLayer::permissive())
184 .fallback(fallback_handler);
185
186 let path_certs = PathBuf::from(self.base_path.clone()).join("certs");
187 let cert_refresher = Arc::new(HttpsTransportCertRefresher::new(path_certs.clone(), self.ip.clone(), self.port, self.force_local_network)?);
188 cert_refresher.reload_certs()?;
189
190 let tls_acceptor = {
191 let mut server_config = ServerConfig::builder().with_no_client_auth().with_cert_resolver(cert_refresher.clone());
192 server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"acme-tls/1".to_vec()];
193 TlsAcceptor::from(Arc::new(server_config))
194 };
195
196 let mut make_service = axum_app.into_make_service_with_connect_info::<std::net::SocketAddr>();
197 let connection_semaphore = Arc::new(Semaphore::new(config::HTTPS_SERVER_TRANSPORT_MAX_CONNECTIONS));
198 let mut join_set: JoinSet<()> = JoinSet::new();
199
200 let ddos = self.ddos_protection.clone();
201
202 let accept_loop = async {
203 loop {
204 while join_set.try_join_next().is_some() {}
206
207 tokio::select! {
208 accept_result = listener.accept() => {
209 let (tcp_stream, peer_addr) = match accept_result {
210 Ok(v) => v,
211 Err(e) => { warn!("accept error: {}", e); continue; }
212 };
213 let ip = peer_addr.ip().to_string();
214
215 let ddos_connection_guard = match DdosConnectionGuard::try_new(ddos.clone(), ip.clone()) {
219 Some(guard) => Arc::new(guard),
220 None => {
221 trace!("DDoS: dropping connection from {} (blocked or per-IP cap reached)", ip);
222 continue;
223 }
224 };
225
226 let permit = match Arc::clone(&connection_semaphore).try_acquire_owned() {
228 Ok(p) => p,
229 Err(_) => {
230 warn!("connection cap ({}) reached, dropping {}", config::HTTPS_SERVER_TRANSPORT_MAX_CONNECTIONS, ip);
231 continue;
232 }
233 };
234
235 let tower_service = match make_service.call(peer_addr).await {
239 Ok(s) => s,
240 Err(e) => { warn!("make_service error for {}: {:?}", ip, e); continue; }
241 };
242
243 let tls_acceptor = tls_acceptor.clone();
244
245 join_set.spawn(async move {
246 let _permit = permit; let tls_stream = match tokio::time::timeout(
251 Duration::from_secs(config::HTTPS_SERVER_TRANSPORT_TLS_HANDSHAKE_TIMEOUT_SECS),
252 tls_acceptor.accept(tcp_stream),
253 ).await {
254 Ok(Ok(s)) => s,
255 Ok(Err(e)) => { trace!("TLS error from {}: {}", ip, e); ddos_connection_guard.report_bad_request(); return; }
256 Err(_) => { trace!("TLS handshake timeout from {}", ip); ddos_connection_guard.report_bad_request(); return; }
257 };
258
259 let io = TokioIo::new(tls_stream);
260
261 let hyper_service = hyper::service::service_fn(move |mut req: hyper::Request<Incoming>| {
266 req.extensions_mut().insert(ddos_connection_guard.clone());
267 tower_service.clone().oneshot(req.map(Body::new))
268 });
269
270 let mut auto_builder = AutoBuilder::new(TokioExecutor::new());
274 auto_builder.http1()
275 .timer(TokioTimer::new())
276 .header_read_timeout(Duration::from_secs(config::HTTPS_SERVER_TRANSPORT_HEADER_READ_TIMEOUT_SECS));
277
278 if let Err(e) = auto_builder.serve_connection(io, hyper_service).await {
279 trace!("connection error from {}: {}", ip, e);
280 }
281 });
282 }
283 _ = cancellation_token.cancelled() => break,
284 }
285 }
286
287 let shutdown_deadline = tokio::time::sleep(Duration::from_secs(config::HTTPS_SERVER_TRANSPORT_SHUTDOWN_TIMEOUT_SECS));
289 tokio::pin!(shutdown_deadline);
290 loop {
291 tokio::select! {
292 result = join_set.join_next() => {
293 match result {
294 None => break,
295 Some(Err(e)) => warn!("connection task error during shutdown: {}", e),
296 Some(Ok(())) => {}
297 }
298 }
299 _ = &mut shutdown_deadline => {
300 join_set.abort_all();
301 break;
302 }
303 }
304 }
305
306 anyhow::Ok(())
307 };
308
309 let results = tokio::join!(
311 accept_loop,
312 cert_refresher.process(cancellation_token.clone()),
313 );
314
315 if let Err(e) = results.0 {
316 error!("error in accept loop: {}", e)
317 }
318 if let Err(e) = results.1 {
319 error!("error in cert refresher: {}", e)
320 }
321
322 info!("stopped listening on address {}", self.address);
323 info!("all open connections complete");
324 *self.state.write() = ServerState::Shutdown;
325
326 Ok(())
327 }
328}
329
330impl FullHttpsTransportFactory {
331 pub fn new(ddos_protection: Arc<dyn DdosProtection>, bootstrap_provider: Arc<dyn BootstrapProvider>) -> Self {
332 let https_transport_factory = hashiverse_lib::transport::partial_https_transport::PartialHttpsTransportFactory::new(bootstrap_provider);
333 Self { ddos_protection, https_transport_factory }
334 }
335}
336
337#[async_trait::async_trait]
338impl TransportFactory for FullHttpsTransportFactory {
339 async fn get_bootstrap_addresses(&self) -> Vec<String> {
340 self.https_transport_factory.get_bootstrap_addresses().await
341 }
342
343 async fn create_server(&self, base_path: &str, port: u16, force_local_network: bool) -> anyhow::Result<Arc<dyn TransportServer>> {
344 let address_to_bind = format!("0.0.0.0:{}", port);
349 info!("bind on: {}", address_to_bind);
350 let listener = TcpListener::bind(address_to_bind).await?;
351
352 let address_bound_ip = get_public_ipv4(force_local_network).await?;
353 let address_bound_port = listener.local_addr()?.port();
354 let address = format!("{}:{}", address_bound_ip, address_bound_port);
355
356 let http_transport_server: Arc<dyn TransportServer> = Arc::new(FullHttpsTransportServer::new(base_path, address, address_bound_ip, address_bound_port, force_local_network, listener, self.ddos_protection.clone()).await?);
357 Ok(http_transport_server)
358 }
359
360 async fn rpc(&self, address: &str, bytes: Bytes) -> anyhow::Result<Bytes> {
361 self.https_transport_factory.rpc(address, bytes).await
362 }
363}
364
365
#[cfg(test)]
mod tests {
    use crate::transport::full_https_transport::FullHttpsTransportFactory;
    use hashiverse_lib::transport::bootstrap_provider::manual_bootstrap_provider::ManualBootstrapProvider;
    use hashiverse_lib::transport::ddos::noop_ddos::NoopDdosProtection;
    use hashiverse_lib::transport::transport::TransportFactory;
    use std::sync::Arc;

    /// Builds the factory under test with no-op DDoS protection and a default manual bootstrap provider.
    fn new_factory() -> Arc<dyn TransportFactory> {
        Arc::new(FullHttpsTransportFactory::new(NoopDdosProtection::default(), ManualBootstrapProvider::default()))
    }

    /// Runs the shared transport round-trip test against this transport.
    #[tokio::test]
    async fn rpc_test() -> anyhow::Result<()> {
        hashiverse_lib::transport::transport::tests::rpc_test(new_factory()).await
    }

    /// Runs the shared "bind to port 0" test against this transport.
    #[tokio::test]
    async fn bind_port_zero_test() -> anyhow::Result<()> {
        hashiverse_lib::transport::transport::tests::bind_port_zero_test(new_factory()).await
    }
}