Skip to main content

hashiverse_lib/transport/ddos/
mem_ddos.rs

1//! # In-memory DDoS accounting
2//!
3//! Implements [`crate::transport::ddos::ddos::DdosProtection`] purely in RAM: per-IP
4//! `DdosScore`s live in a `moka` cache with time-based eviction so idle IPs get
5//! collected automatically, and per-IP connection counts live in a `HashMap` guarded
6//! by a `parking_lot::Mutex`.
7//!
8//! "Ban" here is just a flag in the cache — no kernel-level dropping. That makes this
9//! implementation suitable for tests (the integration harness stresses the scoring
10//! logic without wanting to touch host firewall state) and for platforms where
11//! `ipset`/`iptables` aren't available. The production path in
12//! `hashiverse-server-lib` wraps this with a real firewall-level ban via
13//! [`crate::tools::config::SERVER_DDOS_IPSET_SET_NAME`].
14
15use crate::transport::ddos::ddos::{DdosProtection, DdosScore};
16use log::warn;
17use moka::sync::Cache;
18use parking_lot::Mutex;
19use std::collections::HashMap;
20use std::sync::Arc;
21use std::time::Duration;
22
23/// In-memory DDoS protection with linearly decaying per-IP scores.
24///
25/// Each `allow_request` adds 1.0 point, each `report_bad_request` adds
26/// `bad_request_penalty` points.  Between calls the score drains at
27/// `decay_per_second` points/second, so sustained low-rate traffic stabilises
28/// well below the threshold while bursts trigger quickly.
29///
30/// Scores are stored in a moka cache whose idle expiry is long enough for any
31/// maxed-out score to fully decay, keeping memory bounded.
32pub struct MemDdosProtection {
33    score_threshold: f64,
34    decay_per_second: f64,
35    bad_request_penalty: f64,
36    max_connections_per_ip: usize,
37    scores: Cache<String, Arc<Mutex<DdosScore>>>,
38    connections: Mutex<HashMap<String, usize>>,
39}
40
41impl MemDdosProtection {
42    pub fn new(score_threshold: f64, decay_per_second: f64, bad_request_penalty: f64, max_connections_per_ip: usize) -> Self {
43        // Idle expiry: time for a maxed-out score to fully decay, with 2x margin
44        let idle_secs = if decay_per_second > 0.0 {
45            (score_threshold / decay_per_second * 2.0).ceil() as u64
46        } else {
47            3600 // fallback: 1 hour if no decay
48        };
49        Self {
50            score_threshold,
51            decay_per_second,
52            bad_request_penalty,
53            max_connections_per_ip,
54            scores: Cache::builder().time_to_idle(Duration::from_secs(idle_secs)).build(),
55            connections: Mutex::new(HashMap::new()),
56        }
57    }
58
59    fn increment_score(&self, ip: &str, points: f64) -> f64 {
60        let entry = self.scores.get_with(ip.to_string(), || Arc::new(Mutex::new(DdosScore::new())));
61        entry.lock().increment(points, self.decay_per_second)
62    }
63
64    fn is_score_banned(&self, ip: &str) -> bool {
65        self.scores
66            .get(ip)
67            .map(|entry| entry.lock().current(self.decay_per_second) >= self.score_threshold)
68            .unwrap_or(false)
69    }
70}
71
72impl DdosProtection for MemDdosProtection {
73    fn allow_request(&self, ip: &str) -> bool {
74        self.increment_score(ip, 1.0) < self.score_threshold
75    }
76
77    fn report_bad_request(&self, ip: &str) {
78        let score = self.increment_score(ip, self.bad_request_penalty);
79        if score >= self.score_threshold {
80            warn!("DDoS: {} blocked (score={:.1})", ip, score);
81        }
82    }
83
84    fn try_acquire_connection(&self, ip: &str) -> bool {
85        if self.is_score_banned(ip) {
86            return false;
87        }
88        let mut connections = self.connections.lock();
89        let count = connections.entry(ip.to_string()).or_insert(0);
90        if *count >= self.max_connections_per_ip {
91            return false;
92        }
93        *count += 1;
94        true
95    }
96
97    fn release_connection(&self, ip: &str) {
98        let mut connections = self.connections.lock();
99        if let Some(count) = connections.get_mut(ip) {
100            *count = count.saturating_sub(1);
101            if *count == 0 {
102                connections.remove(ip);
103            }
104        }
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::config;
    use crate::transport::ddos::ddos::DdosConnectionGuard;

    /// Builds a protector configured with the server's DDoS constants from `config`.
    fn make_ddos() -> Arc<MemDdosProtection> {
        Arc::new(MemDdosProtection::new(
            config::SERVER_DDOS_SCORE_THRESHOLD,
            config::SERVER_DDOS_DECAY_PER_SECOND,
            config::SERVER_DDOS_BAD_REQUEST_PENALTY,
            config::SERVER_DDOS_MAX_CONNECTIONS_PER_IP,
        ))
    }

    /// The per-IP cap is enforced, and dropping a guard frees its slot again.
    #[test]
    fn connection_guard_limits_per_ip() {
        let ddos = make_ddos();
        let ip = "1.2.3.4";

        // Take every slot the cap allows, keeping the guards alive.
        let mut guards = vec![];
        for _ in 0..config::SERVER_DDOS_MAX_CONNECTIONS_PER_IP {
            let guard = DdosConnectionGuard::try_new(ddos.clone(), ip).expect("should acquire slot within limit");
            guards.push(guard);
        }

        let over_limit = DdosConnectionGuard::try_new(ddos.clone(), ip);
        assert!(over_limit.is_none(), "should be blocked at per-IP cap");

        // Release one slot — should unblock
        drop(guards.pop().expect("at least one guard should be held"));
        let recovered = DdosConnectionGuard::try_new(ddos.clone(), ip);
        assert!(recovered.is_some(), "should acquire after release");
    }

    /// Connection slots are counted separately for each IP.
    #[test]
    fn connection_guard_independent_ips() {
        let ddos = make_ddos();

        let guard_a = DdosConnectionGuard::try_new(ddos.clone(), "1.1.1.1");
        let guard_b = DdosConnectionGuard::try_new(ddos.clone(), "2.2.2.2");

        assert!(guard_a.is_some());
        assert!(guard_b.is_some());
    }

    /// An IP whose score reached the threshold is refused a connection slot.
    #[test]
    fn banned_ip_cannot_acquire_connection() {
        let ddos = Arc::new(MemDdosProtection::new(3.0, 0.0, 3.0, 8));
        let ip = "1.2.3.4";

        // Exhaust the score to trigger a ban
        while ddos.allow_request(ip) {}

        let guard = DdosConnectionGuard::try_new(ddos.clone(), ip);
        assert!(guard.is_none(), "banned IP should not acquire a connection slot");
    }

    /// The guard's `report_bad_request` / `allow_request` reach the underlying protector.
    #[test]
    fn guard_report_bad_request_delegates() {
        let ddos = Arc::new(MemDdosProtection::new(100.0, 0.0, 5.0, 8));
        let ip = "5.6.7.8";
        let guard = DdosConnectionGuard::try_new(ddos.clone(), ip).expect("fresh IP should acquire a slot");

        guard.report_bad_request();
        // After one bad-request penalty the score is ~5 — still under 100, so allow_request works
        assert!(guard.allow_request());
    }

    /// Releasing the only guard resets the count, so the full cap is available afterwards.
    #[test]
    fn connection_count_drops_to_zero_after_all_guards_released() {
        let ddos = make_ddos();
        let ip = "9.9.9.9";

        let guard = DdosConnectionGuard::try_new(ddos.clone(), ip).expect("fresh IP should acquire a slot");
        drop(guard);

        // If the dropped guard had leaked its slot, one of these acquisitions would fail.
        let mut guards = vec![];
        for _ in 0..config::SERVER_DDOS_MAX_CONNECTIONS_PER_IP {
            let guard = DdosConnectionGuard::try_new(ddos.clone(), ip).expect("all slots should be free after release");
            guards.push(guard);
        }
        assert!(DdosConnectionGuard::try_new(ddos.clone(), ip).is_none());
    }

    /// The request that brings the score up to the threshold is itself refused.
    #[test]
    fn allow_request_returns_false_at_threshold() {
        // Use zero decay so timing doesn't affect the test
        let threshold = 5.0;
        let ddos = Arc::new(MemDdosProtection::new(threshold, 0.0, 1.0, 8));
        let ip = "3.3.3.3";

        for i in 0..4 {
            assert!(ddos.allow_request(ip), "request {} of 5 should be allowed", i + 1);
        }
        // 5th call reaches the limit
        assert!(!ddos.allow_request(ip), "request at threshold should be blocked");
        assert!(!ddos.allow_request(ip), "subsequent requests must also be blocked");
    }

    /// Bad-request penalties reach the threshold in fewer calls than plain requests would.
    #[test]
    fn bad_request_penalty_causes_ban_faster_than_normal_requests() {
        let threshold = 20.0;
        let penalty = 10.0;
        let ddos = Arc::new(MemDdosProtection::new(threshold, 0.0, penalty, 8));
        let ip = "4.4.4.4";

        ddos.report_bad_request(ip); // score = 10
        ddos.report_bad_request(ip); // score = 20 — at threshold

        assert!(!ddos.allow_request(ip), "IP should be banned after two penalty-weight bad requests");
        assert!(DdosConnectionGuard::try_new(ddos.clone(), ip).is_none(), "banned IP must not acquire a connection");
    }

    /// Exhausting one IP's score leaves other IPs untouched.
    #[test]
    fn score_is_independent_per_ip() {
        let threshold = 3.0;
        let ddos = Arc::new(MemDdosProtection::new(threshold, 0.0, 1.0, 8));
        let ip_a = "10.0.0.1";
        let ip_b = "10.0.0.2";

        while ddos.allow_request(ip_a) {}
        assert!(!ddos.allow_request(ip_a), "ip_a should be blocked");

        assert!(ddos.allow_request(ip_b), "ip_b should be unaffected by ip_a's exhaustion");
        let guard_b = DdosConnectionGuard::try_new(ddos.clone(), ip_b);
        assert!(guard_b.is_some(), "ip_b should still acquire a connection after ip_a is banned");
    }

    /// A score just under the threshold drains during a short pause.
    #[test]
    fn score_decays_over_time() {
        // Low threshold, very fast decay: accumulated points drain within milliseconds
        let ddos = Arc::new(MemDdosProtection::new(5.0, 1000.0, 1.0, 8));
        let ip = "7.7.7.7";

        // Add up to 4 points (just under threshold)
        for _ in 0..4 {
            ddos.allow_request(ip);
        }

        // With decay_per_second=1000 each millisecond drains one point, so a 10 ms
        // pause removes more than the 4 points added above.
        // Next request should be allowed because the score has decayed
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(ddos.allow_request(ip), "score should have decayed, allowing the request");
    }
}