hashiverse_lib/transport/ddos/
mem_ddos.rs1use crate::transport::ddos::ddos::{DdosProtection, DdosScore};
16use log::warn;
17use moka::sync::Cache;
18use parking_lot::Mutex;
19use std::collections::HashMap;
20use std::sync::Arc;
21use std::time::Duration;
22
23pub struct MemDdosProtection {
33 score_threshold: f64,
34 decay_per_second: f64,
35 bad_request_penalty: f64,
36 max_connections_per_ip: usize,
37 scores: Cache<String, Arc<Mutex<DdosScore>>>,
38 connections: Mutex<HashMap<String, usize>>,
39}
40
41impl MemDdosProtection {
42 pub fn new(score_threshold: f64, decay_per_second: f64, bad_request_penalty: f64, max_connections_per_ip: usize) -> Self {
43 let idle_secs = if decay_per_second > 0.0 {
45 (score_threshold / decay_per_second * 2.0).ceil() as u64
46 } else {
47 3600 };
49 Self {
50 score_threshold,
51 decay_per_second,
52 bad_request_penalty,
53 max_connections_per_ip,
54 scores: Cache::builder().time_to_idle(Duration::from_secs(idle_secs)).build(),
55 connections: Mutex::new(HashMap::new()),
56 }
57 }
58
59 fn increment_score(&self, ip: &str, points: f64) -> f64 {
60 let entry = self.scores.get_with(ip.to_string(), || Arc::new(Mutex::new(DdosScore::new())));
61 entry.lock().increment(points, self.decay_per_second)
62 }
63
64 fn is_score_banned(&self, ip: &str) -> bool {
65 self.scores
66 .get(ip)
67 .map(|entry| entry.lock().current(self.decay_per_second) >= self.score_threshold)
68 .unwrap_or(false)
69 }
70}
71
72impl DdosProtection for MemDdosProtection {
73 fn allow_request(&self, ip: &str) -> bool {
74 self.increment_score(ip, 1.0) < self.score_threshold
75 }
76
77 fn report_bad_request(&self, ip: &str) {
78 let score = self.increment_score(ip, self.bad_request_penalty);
79 if score >= self.score_threshold {
80 warn!("DDoS: {} blocked (score={:.1})", ip, score);
81 }
82 }
83
84 fn try_acquire_connection(&self, ip: &str) -> bool {
85 if self.is_score_banned(ip) {
86 return false;
87 }
88 let mut connections = self.connections.lock();
89 let count = connections.entry(ip.to_string()).or_insert(0);
90 if *count >= self.max_connections_per_ip {
91 return false;
92 }
93 *count += 1;
94 true
95 }
96
97 fn release_connection(&self, ip: &str) {
98 let mut connections = self.connections.lock();
99 if let Some(count) = connections.get_mut(ip) {
100 *count = count.saturating_sub(1);
101 if *count == 0 {
102 connections.remove(ip);
103 }
104 }
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::config;
    use crate::transport::ddos::ddos::DdosConnectionGuard;

    /// Builds a protector using the server's configured DDoS limits.
    fn make_ddos() -> Arc<MemDdosProtection> {
        Arc::new(MemDdosProtection::new(
            config::SERVER_DDOS_SCORE_THRESHOLD,
            config::SERVER_DDOS_DECAY_PER_SECOND,
            config::SERVER_DDOS_BAD_REQUEST_PENALTY,
            config::SERVER_DDOS_MAX_CONNECTIONS_PER_IP,
        ))
    }

    #[test]
    fn connection_guard_limits_per_ip() {
        let ddos = make_ddos();
        let ip = "1.2.3.4";

        let mut guards = vec![];
        for _ in 0..config::SERVER_DDOS_MAX_CONNECTIONS_PER_IP {
            let guard = DdosConnectionGuard::try_new(ddos.clone(), ip);
            assert!(guard.is_some(), "should acquire slot within limit");
            guards.push(guard.unwrap());
        }

        let over_limit = DdosConnectionGuard::try_new(ddos.clone(), ip);
        assert!(over_limit.is_none(), "should be blocked at per-IP cap");

        // Dropping one guard must hand its slot back.
        drop(guards.pop().unwrap());
        let recovered = DdosConnectionGuard::try_new(ddos.clone(), ip);
        assert!(recovered.is_some(), "should acquire after release");
    }

    #[test]
    fn connection_guard_independent_ips() {
        let ddos = make_ddos();

        let guard_a = DdosConnectionGuard::try_new(ddos.clone(), "1.1.1.1");
        let guard_b = DdosConnectionGuard::try_new(ddos.clone(), "2.2.2.2");

        assert!(guard_a.is_some());
        assert!(guard_b.is_some());
    }

    #[test]
    fn banned_ip_cannot_acquire_connection() {
        // No decay, so the loop terminates once the score reaches the threshold.
        let ddos = Arc::new(MemDdosProtection::new(3.0, 0.0, 3.0, 8));
        let ip = "1.2.3.4";

        while ddos.allow_request(ip) {}

        let guard = DdosConnectionGuard::try_new(ddos.clone(), ip);
        assert!(guard.is_none(), "banned IP should not acquire a connection slot");
    }

    #[test]
    fn guard_report_bad_request_delegates() {
        let ddos = Arc::new(MemDdosProtection::new(100.0, 0.0, 5.0, 8));
        let ip = "5.6.7.8";
        let guard = DdosConnectionGuard::try_new(ddos.clone(), ip).unwrap();

        guard.report_bad_request();
        assert!(guard.allow_request());
    }

    #[test]
    fn guard_report_bad_request_charges_penalty() {
        // No decay, threshold 10, penalty 5. A penalty reported through the
        // guard must cost the reported IP exactly five unit-cost requests of
        // budget relative to an otherwise identical control IP. Comparing
        // against a control keeps the test independent of whatever
        // `DdosConnectionGuard::try_new` itself charges.
        let ddos = Arc::new(MemDdosProtection::new(10.0, 0.0, 5.0, 8));
        let reported = "5.6.7.8";
        let control = "5.6.7.9";
        let reported_guard = DdosConnectionGuard::try_new(ddos.clone(), reported).unwrap();
        let _control_guard = DdosConnectionGuard::try_new(ddos.clone(), control).unwrap();

        reported_guard.report_bad_request();

        let remaining_budget = |ip: &str| (0..20).take_while(|_| ddos.allow_request(ip)).count();
        let control_budget = remaining_budget(control);
        let reported_budget = remaining_budget(reported);
        assert_eq!(
            reported_budget + 5,
            control_budget,
            "penalty reported via the guard must count against the guarded IP"
        );
    }

    #[test]
    fn connection_count_drops_to_zero_after_all_guards_released() {
        let ddos = make_ddos();
        let ip = "9.9.9.9";

        let guard = DdosConnectionGuard::try_new(ddos.clone(), ip).unwrap();
        drop(guard);

        // The released slot must not linger: the full per-IP allowance is available again.
        let mut guards = vec![];
        for _ in 0..config::SERVER_DDOS_MAX_CONNECTIONS_PER_IP {
            guards.push(DdosConnectionGuard::try_new(ddos.clone(), ip).unwrap());
        }
        assert!(DdosConnectionGuard::try_new(ddos.clone(), ip).is_none());
    }

    #[test]
    fn allow_request_returns_false_at_threshold() {
        let threshold = 5.0;
        // No decay: each request adds exactly one point.
        let ddos = Arc::new(MemDdosProtection::new(threshold, 0.0, 1.0, 8));
        let ip = "3.3.3.3";

        for i in 0..4 {
            assert!(ddos.allow_request(ip), "request {} of 5 should be allowed", i + 1);
        }
        assert!(!ddos.allow_request(ip), "request at threshold should be blocked");
        assert!(!ddos.allow_request(ip), "subsequent requests must also be blocked");
    }

    #[test]
    fn bad_request_penalty_causes_ban_faster_than_normal_requests() {
        let threshold = 20.0;
        let penalty = 10.0;
        let ddos = Arc::new(MemDdosProtection::new(threshold, 0.0, penalty, 8));
        let ip = "4.4.4.4";

        ddos.report_bad_request(ip);
        ddos.report_bad_request(ip);
        assert!(!ddos.allow_request(ip), "IP should be banned after two penalty-weight bad requests");
        assert!(DdosConnectionGuard::try_new(ddos.clone(), ip).is_none(), "banned IP must not acquire a connection");
    }

    #[test]
    fn score_is_independent_per_ip() {
        let threshold = 3.0;
        let ddos = Arc::new(MemDdosProtection::new(threshold, 0.0, 1.0, 8));
        let ip_a = "10.0.0.1";
        let ip_b = "10.0.0.2";

        while ddos.allow_request(ip_a) {}
        assert!(!ddos.allow_request(ip_a), "ip_a should be blocked");

        assert!(ddos.allow_request(ip_b), "ip_b should be unaffected by ip_a's exhaustion");
        let guard_b = DdosConnectionGuard::try_new(ddos.clone(), ip_b);
        assert!(guard_b.is_some(), "ip_b should still acquire a connection after ip_a is banned");
    }

    #[test]
    fn score_decays_over_time() {
        // Decay of 1000 points/s against a threshold of 5: the score of 4
        // accumulated below should be gone well within the 10 ms sleep.
        let ddos = Arc::new(MemDdosProtection::new(5.0, 1000.0, 1.0, 8));
        let ip = "7.7.7.7";

        for _ in 0..4 {
            ddos.allow_request(ip);
        }

        std::thread::sleep(std::time::Duration::from_millis(10));

        assert!(ddos.allow_request(ip), "score should have decayed, allowing the request");
    }
}