1use crate::tools::json;
21use crate::tools::time::DurationMillis;
22use crate::tools::time_provider::time_provider::{RealTimeProvider, TimeProvider};
23use crate::tools::BytesGatherer;
24use argon2::password_hash::rand_core::{OsRng, RngCore};
25use base64::Engine;
26use bytes::Bytes;
27use log::info;
28use std::fmt;
29use std::future::Future;
30use std::sync::Arc;
31use tokio_util::sync::CancellationToken;
32use tracing_subscriber::fmt::time::FormatTime;
33use tracing_subscriber::layer::SubscriberExt;
34use tracing_subscriber::util::SubscriberInitExt;
35
/// Count of leading bits on which two keys agree; the return type of [`leading_agreement_bits_xor`].
pub type LeadingAgreementBits = i32;
37
38pub async fn yield_now() {
39 #[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
40 {
41 send_wrapper::SendWrapper::new(gloo_timers::future::TimeoutFuture::new(0)).await;
42 }
43 #[cfg(all(target_arch = "wasm32", target_os = "wasi"))]
44 {
45 tokio::time::sleep(std::time::Duration::from_millis(0u64)).await;
46 }
47 #[cfg(not(target_arch = "wasm32"))]
48 {
49 tokio::task::yield_now().await;
51 }
52}
53
54pub fn random_fill_bytes(dest: &mut [u8]) {
55 OsRng.fill_bytes(dest);
56}
57
58pub fn random_bytes(n: usize) -> Vec<u8> {
59 let mut dest = vec![0u8; n];
60 random_fill_bytes(&mut dest);
61 dest
62}
63
/// Returns a copy of `bytes` with the byte order reversed; the input is left untouched.
pub fn reverse_bytes<const N: usize>(bytes: &[u8; N]) -> [u8; N] {
    let mut flipped = *bytes;
    flipped.reverse();
    flipped
}
71
72pub fn random_u32() -> u32 {
73 OsRng.next_u32()
74}
75
76#[cfg(target_pointer_width = "64")]
77pub fn random_usize() -> usize {
78 OsRng.next_u64() as usize
79}
80
81#[cfg(target_pointer_width = "32")]
82pub fn random_usize() -> usize {
83 OsRng.next_u32() as usize
84}
85
86pub fn random_usize_bounded(upper: usize) -> usize {
87 let zone = usize::MAX - (usize::MAX % upper);
90
91 loop {
92 let r = random_usize();
93 if r < zone {
94 return r % upper;
95 }
96 }
97}
98
99pub fn random_u8() -> u8 {
100 OsRng.next_u32() as u8
101}
102
103pub fn random_base64(length: usize) -> String {
104 let mut bytes = vec![0u8; length];
105 random_fill_bytes(&mut bytes);
106 encode_base64(bytes)
107}
108
109pub fn are_all_zeros<T: PartialEq + num_traits::Zero>(src: &[T]) -> bool {
110 src.iter().all(|b| *b == T::zero())
111}
112
/// Returns `true` when both slices have the same length and are element-wise equal.
pub fn are_all_equal<T: PartialEq>(src1: &[T], src2: &[T]) -> bool {
    // Slice equality already checks the length first, then compares element by element.
    src1 == src2
}
119
/// Counts the zero bits at the start of `bytes`, starting from the most significant bit of the
/// first byte. Counting stops at the first set bit.
///
/// The result saturates at `255` (`u8::MAX`) when 256 or more leading zero bits are present.
pub fn count_leading_zero_bits(bytes: &[u8]) -> u8 {
    let mut total: u64 = 0;

    for byte in bytes.iter().copied() {
        // A zero byte contributes all 8 bits; the first non-zero byte contributes its own
        // leading zeros and ends the prefix.
        total += u64::from(byte.leading_zeros());
        if byte != 0 {
            break;
        }
    }

    u8::try_from(total).unwrap_or(u8::MAX)
}
141
142pub async fn cancellable_sleep_millis(time_provider: &dyn TimeProvider, millis: DurationMillis, cancellation_token: &CancellationToken) {
143 tokio::select! {
144 _ = time_provider.sleep_millis(millis) => {},
145 _ = cancellation_token.cancelled() => {},
146 }
147}
148
/// Renders `items` using their `Display` impls as `"[ a, b, c ]"`.
///
/// An empty slice renders as `"[  ]"`.
pub fn format_vec<T: std::fmt::Display>(items: &[T]) -> String {
    let rendered: Vec<String> = items.iter().map(ToString::to_string).collect();
    let body = rendered.join(", ");
    format!("[ {} ]", body)
}
152
153pub fn leading_agreement_bits_xor(key1: &[u8], key2: &[u8]) -> LeadingAgreementBits {
154 let mut leading_bits_in_agreement: i32 = 0;
155
156 let min_len = std::cmp::min(key1.len(), key2.len());
157 for byte_idx in 0..min_len {
158 let xor = key1[byte_idx] ^ key2[byte_idx];
159
160 if xor != 0 {
162 leading_bits_in_agreement += xor.leading_zeros() as LeadingAgreementBits;
163 return leading_bits_in_agreement;
164 }
165 else {
166 leading_bits_in_agreement += 8;
167 }
168 }
169
170 leading_bits_in_agreement
171}
172
173pub fn encode_base64<T: AsRef<[u8]>>(input: T) -> String {
174 base64::engine::general_purpose::STANDARD.encode(&input)
175}
176
177pub fn decode_base64<T: AsRef<[u8]>>(input: T) -> anyhow::Result<Vec<u8>> {
178 Ok(base64::engine::general_purpose::STANDARD.decode(input)?)
179}
180
/// Encodes `v` as 8 little-endian bytes, regardless of the platform's pointer width.
pub fn usize_encode_le64(v: usize) -> [u8; 8] {
    let widened = v as u64;
    widened.to_le_bytes()
}
184
185pub fn usize_decode_le64(v_bytes: &[u8]) -> anyhow::Result<usize> {
186 let v = u64::from_le_bytes(v_bytes.try_into()?);
187 Ok(v as usize)
188}
189
190pub fn write_length_prefixed_json<T: serde::Serialize>(bytes_gatherer: &mut BytesGatherer, value: &T) -> anyhow::Result<()> {
191 let json_bytes = json::struct_to_bytes(value)?;
192 bytes_gatherer.put_u64(json_bytes.len() as u64);
193 bytes_gatherer.put_bytes(json_bytes);
194 Ok(())
195}
196pub fn read_length_prefixed_json<T: serde::de::DeserializeOwned>(bytes: &mut Bytes) -> anyhow::Result<T> {
197 use bytes::Buf;
198
199 if bytes.remaining() < 8 {
200 anyhow::bail!("Invalid buffer: missing json length");
201 }
202
203 let len = bytes.get_u64() as usize;
204
205 if bytes.remaining() < len {
206 anyhow::bail!("Invalid buffer: json data truncated");
207 }
208
209 let json_bytes = bytes.copy_to_bytes(len);
210 json::bytes_to_struct::<T>(&json_bytes)
211}
212
#[cfg(test)]
mod tests {
    use super::{count_leading_zero_bits, leading_agreement_bits_xor, reverse_bytes};

    /// Checks `leading_agreement_bits_xor` against hand-computed prefixes, in both argument
    /// orders (the result must not depend on which key comes first).
    #[test]
    fn xor_distance_bits_test() -> anyhow::Result<()> {
        let tests = [
            // Identical keys agree on every bit.
            ("0000", "0000", 16),
            ("ffff", "ffff", 16),
            ("1234", "1234", 16),
            ("abcd", "abcd", 16),
            // First differing bit at every position of a 16-bit key.
            ("0000", "ffff", 0),
            ("0000", "efff", 0),
            ("0000", "7fff", 1),
            ("0000", "3fff", 2),
            ("0000", "1fff", 3),
            ("0000", "0fff", 4),
            ("0000", "07ff", 5),
            ("0000", "03ff", 6),
            ("0000", "01ff", 7),
            ("0000", "00ff", 8),
            ("0000", "007f", 9),
            ("0000", "003f", 10),
            ("0000", "001f", 11),
            ("0000", "000f", 12),
            ("0000", "0007", 13),
            ("0000", "0003", 14),
            ("0000", "0001", 15),
            // Bits after the first difference do not matter.
            ("0000", "fff9", 0),
            ("0000", "0ff9", 4),
            ("0000", "00f9", 8),
            // Only the common prefix of differently sized keys is compared.
            ("", "0000", 0),
            ("00", "0000", 8),
            ("0000", "000000", 16),
        ];

        for (a, b, expected) in tests {
            let a_binary = hex::decode(a)?;
            let b_binary = hex::decode(b)?;

            let forward = leading_agreement_bits_xor(&a_binary, &b_binary);
            assert_eq!(forward, expected, "Failed for {} and {}. Got {} expected {}.", a, b, forward, expected);

            let backward = leading_agreement_bits_xor(&b_binary, &a_binary);
            assert_eq!(backward, expected, "Failed for {} and {} (swapped). Got {} expected {}.", b, a, backward, expected);
        }
        Ok(())
    }

    /// Leading zero bits are counted MSB-first and saturate at 255.
    #[test]
    fn count_leading_zero_bits_test() {
        assert_eq!(count_leading_zero_bits(&[]), 0);
        assert_eq!(count_leading_zero_bits(&[0x80]), 0);
        assert_eq!(count_leading_zero_bits(&[0x01]), 7);
        assert_eq!(count_leading_zero_bits(&[0x00, 0x10]), 11);
        assert_eq!(count_leading_zero_bits(&[0x00, 0x00]), 16);

        let mut almost = [0u8; 32];
        almost[31] = 0x80;
        assert_eq!(count_leading_zero_bits(&almost), 248);
        assert_eq!(count_leading_zero_bits(&[0u8; 32]), 255);
        assert_eq!(count_leading_zero_bits(&[0u8; 64]), 255);
    }

    #[test]
    fn reverse_bytes_test() {
        assert_eq!(reverse_bytes(&[1u8, 2, 3, 4]), [4u8, 3, 2, 1]);
        assert_eq!(reverse_bytes(&[0u8; 0]), [0u8; 0]);
    }
}
272
273pub fn random_element<T>(range: &[T]) -> &T {
274 let index = random_usize_bounded(range.len());
275 &range[index]
276}
277
278pub fn shuffle<T>(source: &mut [T]) {
279 for i in 1..source.len() {
281 let j = random_usize_bounded(i + 1);
282 source.swap(i, j);
283 }
284}
285
286pub struct CustomTimeFormatter {
287 time_provider: Arc<dyn TimeProvider>,
288}
289
290impl CustomTimeFormatter {
291 pub fn new(time_provider: Arc<dyn TimeProvider>) -> Self {
292 Self { time_provider }
293 }
294}
295
296impl FormatTime for CustomTimeFormatter {
297 fn format_time(&self, w: &mut tracing_subscriber::fmt::format::Writer<'_>) -> fmt::Result {
298 write!(w, "{}", self.time_provider.current_time_str())
299 }
300}
301
302pub fn configure_logging() {
303 configure_logging_with_time_provider("trace", Arc::new(RealTimeProvider))
304}
305
306pub fn configure_logging_with_time_provider(level: &str, time_provider: Arc<dyn TimeProvider>) {
307 let filter = format!("{},hyper=off,warp=off,reqwest=off,rustls=off,h2=off,h2=off,html5ever=off,selectors=off,fjall=off,lsm_tree=off,sfa=off,hickory_resolver=off,hickory_proto=off", level);
309 let env_filter = tracing_subscriber::EnvFilter::new(&filter);
310
311 let fmt_layer = tracing_subscriber::fmt::layer().with_timer(CustomTimeFormatter::new(time_provider));
313
314 let registry = tracing_subscriber::registry();
315
316 #[cfg(all(tokio_unstable, not(target_arch = "wasm32")))]
318 registry.with(console_subscriber::spawn());
319
320 registry.with(fmt_layer).with(env_filter).init();
322
323 info!("Logging initialized");
324}
325
326#[cfg(not(target_arch = "wasm32"))]
327pub type TempDirHandle = tempfile::TempDir;
328
329#[cfg(not(target_arch = "wasm32"))]
330pub fn get_temp_dir() -> anyhow::Result<(TempDirHandle, String)> {
331 let mut base = std::env::temp_dir();
332 base.push("hashiverse-temp");
333
334 std::fs::create_dir_all(&base)?;
336
337 let temp_dir = tempfile::Builder::new().prefix("hashiverse-").tempdir_in(&base)?;
338 let temp_dir_path = temp_dir.path().to_str().unwrap().to_string();
339 Ok((temp_dir, temp_dir_path))
340}
341
/// On wasm32 the temporary-directory handle is a placeholder unit value.
#[cfg(target_arch = "wasm32")]
pub type TempDirHandle = ();

/// wasm32 stand-in for the native `get_temp_dir`. It always succeeds with a unit handle and an
/// empty path string.
#[cfg(target_arch = "wasm32")]
pub fn get_temp_dir() -> anyhow::Result<(TempDirHandle, String)> {
    let handle: TempDirHandle = ();
    Ok((handle, String::new()))
}
349
350pub fn from_hex_str<T, const T_BYTES: usize>(str: &str, ctor: impl FnOnce([u8; T_BYTES]) -> T) -> anyhow::Result<T> {
351 if str.len() != 2 * T_BYTES {
352 anyhow::bail!("Invalid hex string length: expected {} hex characters ({} bytes), got {} characters.", 2 * T_BYTES, T_BYTES, str.len(),);
353 }
354
355 let decoded = hex::decode(str)?;
357
358 if decoded.len() != T_BYTES {
360 anyhow::bail!("Invalid hex string length: expected {} bytes, got {} bytes", T_BYTES, decoded.len());
361 }
362
363 let mut decoded_bytes = [0u8; T_BYTES];
365 decoded_bytes.copy_from_slice(&decoded);
366
367 Ok(ctor(decoded_bytes))
368}
369
370pub fn spawn_background_task<F>(task: F)
372where
373 F: Future<Output = ()> + Send + 'static,
374{
375 #[cfg(not(target_arch = "wasm32"))]
376 tokio::spawn(task);
377 #[cfg(target_arch = "wasm32")]
378 wasm_bindgen_futures::spawn_local(task);
379}
380