Skip to main content

hashiverse_lib/tools/
tools.rs

1//! # Grab-bag of tiny cross-platform helpers
2//!
3//! Utility functions that don't fit in any of the more focused modules:
4//!
5//! - **Async yielding** ([`yield_now`]) — maps to `tokio::task::yield_now` on native,
6//!   `gloo_timers` on wasm32-unknown, and `tokio::time::sleep(0)` on wasi.
7//! - **Randomness** ([`random_fill_bytes`], [`random_bytes`], [`random_u32`]) — OS RNG
8//!   helpers used by key generation and PoW salt selection.
9//! - **Base64 and hex parsing** — consistent helpers used wherever we need to emit or
10//!   accept textual byte blobs (key persistence, URLs, HTML attributes).
11//! - **Byte reversal** used by server-id PoW hash-to-id mapping.
12//! - **`LeadingAgreementBits`** typedef for the XOR-distance metric used by the DHT.
13//! - **Logging bootstrap** (`tracing_subscriber` initialisation with consistent
14//!   formatting across native and wasm).
15//! - **`Cancellable`-style async helpers** that plug into `CancellationToken`.
16//!
17//! Anything here that grows a meaningful amount of functionality should graduate to
18//! its own module.
19
20use crate::tools::json;
21use crate::tools::time::DurationMillis;
22use crate::tools::time_provider::time_provider::{RealTimeProvider, TimeProvider};
23use crate::tools::BytesGatherer;
24use argon2::password_hash::rand_core::{OsRng, RngCore};
25use base64::Engine;
26use bytes::Bytes;
27use log::info;
28use std::fmt;
29use std::future::Future;
30use std::sync::Arc;
31use tokio_util::sync::CancellationToken;
32use tracing_subscriber::fmt::time::FormatTime;
33use tracing_subscriber::layer::SubscriberExt;
34use tracing_subscriber::util::SubscriberInitExt;
35
/// Count of leading bits on which two byte strings agree, as returned by
/// `leading_agreement_bits_xor`.
pub type LeadingAgreementBits = i32;
37
38pub async fn yield_now() {
39    #[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
40    {
41        send_wrapper::SendWrapper::new(gloo_timers::future::TimeoutFuture::new(0)).await;
42    }
43    #[cfg(all(target_arch = "wasm32", target_os = "wasi"))]
44    {
45        tokio::time::sleep(std::time::Duration::from_millis(0u64)).await;
46    }
47    #[cfg(not(target_arch = "wasm32"))]
48    {
49        // On native platforms, use Tokio's optimized yield.
50        tokio::task::yield_now().await;
51    }
52}
53
54pub fn random_fill_bytes(dest: &mut [u8]) {
55    OsRng.fill_bytes(dest);
56}
57
58pub fn random_bytes(n: usize) -> Vec<u8> {
59    let mut dest = vec![0u8; n];
60    random_fill_bytes(&mut dest);
61    dest
62}
63
/// Returns a copy of `bytes` with the byte order reversed; the input is left untouched.
pub fn reverse_bytes<const N: usize>(bytes: &[u8; N]) -> [u8; N] {
    let mut reversed = *bytes;
    reversed.reverse();
    reversed
}
71
72pub fn random_u32() -> u32 {
73    OsRng.next_u32()
74}
75
76#[cfg(target_pointer_width = "64")]
77pub fn random_usize() -> usize {
78    OsRng.next_u64() as usize
79}
80
81#[cfg(target_pointer_width = "32")]
82pub fn random_usize() -> usize {
83    OsRng.next_u32() as usize
84}
85
86pub fn random_usize_bounded(upper: usize) -> usize {
87    // Rejection sampling to avoid modulo bias.
88    // We accept values in [0, zone) where zone is the largest multiple of `upper`.
89    let zone = usize::MAX - (usize::MAX % upper);
90
91    loop {
92        let r = random_usize();
93        if r < zone {
94            return r % upper;
95        }
96    }
97}
98
99pub fn random_u8() -> u8 {
100    OsRng.next_u32() as u8
101}
102
103pub fn random_base64(length: usize) -> String {
104    let mut bytes = vec![0u8; length];
105    random_fill_bytes(&mut bytes);
106    encode_base64(bytes)
107}
108
109pub fn are_all_zeros<T: PartialEq + num_traits::Zero>(src: &[T]) -> bool {
110    src.iter().all(|b| *b == T::zero())
111}
112
/// Returns `true` when both slices have the same length and are equal element by element.
pub fn are_all_equal<T: PartialEq>(src1: &[T], src2: &[T]) -> bool {
    // Slice equality already compares lengths first and then the elements pairwise.
    src1 == src2
}
119
/// Counts the zero bits at the front of `bytes`, scanning each byte from its most
/// significant bit, and stops at the first set bit.
///
/// The result saturates at 255: any input with 255 or more leading zero bits returns 255.
pub fn count_leading_zero_bits(bytes: &[u8]) -> u8 {
    let mut zero_bits: u64 = 0;

    for &byte in bytes {
        // A zero byte contributes all 8 bits; a non-zero byte contributes its own
        // leading zeros and ends the scan.
        zero_bits += u64::from(byte.leading_zeros());
        if byte != 0 {
            break;
        }
    }

    zero_bits.min(u64::from(u8::MAX)) as u8
}
141
142pub async fn cancellable_sleep_millis(time_provider: &dyn TimeProvider, millis: DurationMillis, cancellation_token: &CancellationToken) {
143    tokio::select! {
144        _ = time_provider.sleep_millis(millis) => {},
145        _ = cancellation_token.cancelled() => {},
146    }
147}
148
/// Renders `items` as `[ a, b, c ]` using each element's `Display` output.
///
/// An empty slice renders as `[  ]` (two spaces between the brackets).
pub fn format_vec<T: std::fmt::Display>(items: &[T]) -> String {
    let rendered: Vec<String> = items.iter().map(ToString::to_string).collect();
    let joined = rendered.join(", ");
    format!("[ {} ]", joined)
}
152
153pub fn leading_agreement_bits_xor(key1: &[u8], key2: &[u8]) -> LeadingAgreementBits {
154    let mut leading_bits_in_agreement: i32 = 0;
155
156    let min_len = std::cmp::min(key1.len(), key2.len());
157    for byte_idx in 0..min_len {
158        let xor = key1[byte_idx] ^ key2[byte_idx];
159
160        // Do we have differing bytes?
161        if xor != 0 {
162            leading_bits_in_agreement += xor.leading_zeros() as LeadingAgreementBits;
163            return leading_bits_in_agreement;
164        }
165        else {
166            leading_bits_in_agreement += 8;
167        }
168    }
169
170    leading_bits_in_agreement
171}
172
173pub fn encode_base64<T: AsRef<[u8]>>(input: T) -> String {
174    base64::engine::general_purpose::STANDARD.encode(&input)
175}
176
177pub fn decode_base64<T: AsRef<[u8]>>(input: T) -> anyhow::Result<Vec<u8>> {
178    Ok(base64::engine::general_purpose::STANDARD.decode(input)?)
179}
180
/// Encodes `v` as 8 little-endian bytes after widening it to `u64`.
pub fn usize_encode_le64(v: usize) -> [u8; 8] {
    let widened = v as u64;
    widened.to_le_bytes()
}
184
185pub fn usize_decode_le64(v_bytes: &[u8]) -> anyhow::Result<usize> {
186    let v = u64::from_le_bytes(v_bytes.try_into()?);
187    Ok(v as usize)
188}
189
190pub fn write_length_prefixed_json<T: serde::Serialize>(bytes_gatherer: &mut BytesGatherer, value: &T) -> anyhow::Result<()> {
191    let json_bytes = json::struct_to_bytes(value)?;
192    bytes_gatherer.put_u64(json_bytes.len() as u64);
193    bytes_gatherer.put_bytes(json_bytes);
194    Ok(())
195}
196pub fn read_length_prefixed_json<T: serde::de::DeserializeOwned>(bytes: &mut Bytes) -> anyhow::Result<T> {
197    use bytes::Buf;
198
199    if bytes.remaining() < 8 {
200        anyhow::bail!("Invalid buffer: missing json length");
201    }
202
203    let len = bytes.get_u64() as usize;
204
205    if bytes.remaining() < len {
206        anyhow::bail!("Invalid buffer: json data truncated");
207    }
208
209    let json_bytes = bytes.copy_to_bytes(len);
210    json::bytes_to_struct::<T>(&json_bytes)
211}
212
#[cfg(test)]
mod tests {
    /// Table-driven check of `leading_agreement_bits_xor`, evaluated with the keys in
    /// both argument orders.
    #[tokio::test]
    async fn xor_distance_bits_test() -> anyhow::Result<()> {
        use crate::tools::tools::leading_agreement_bits_xor;

        // Each case is (key a as hex, key b as hex, expected agreeing leading bits).
        let cases = [
            // Identical
            ("0000", "0000", 16),
            ("ffff", "ffff", 16),
            ("1234", "1234", 16),
            ("abcd", "abcd", 16),
            // MSB
            ("0000", "ffff", 0),
            ("0000", "0fff", 4),
            ("0000", "00ff", 8),
            ("0000", "000f", 12),
            // Units
            ("0000", "efff", 0),
            ("0000", "7fff", 1),
            ("0000", "3fff", 2),
            ("0000", "1fff", 3),
            ("0000", "0fff", 4),
            ("0000", "07ff", 5),
            ("0000", "03ff", 6),
            ("0000", "01ff", 7),
            ("0000", "00ff", 8),
            ("0000", "007f", 9),
            ("0000", "003f", 10),
            ("0000", "001f", 11),
            ("0000", "000f", 12),
            ("0000", "0007", 13),
            ("0000", "0003", 14),
            ("0000", "0001", 15),
            // MSB + random
            ("0000", "fff9", 0),
            ("0000", "0ff9", 4),
            ("0000", "00f9", 8),
            // Different lengths
            ("", "0000", 0),
            ("00", "0000", 8),
            ("0000", "000000", 16),
        ];

        for (a, b, expected) in cases {
            let a_binary = hex::decode(a)?;
            let b_binary = hex::decode(b)?;

            // The result must not depend on argument order, so run (a, b) and then (b, a).
            for (lhs, rhs) in [(&a_binary, &b_binary), (&b_binary, &a_binary)] {
                let distance = leading_agreement_bits_xor(lhs, rhs);
                assert_eq!(distance, expected, "Failed for {} and {}.  Got {} expected {}.", a, b, distance, expected);
            }
        }

        Ok(())
    }
}
272
273pub fn random_element<T>(range: &[T]) -> &T {
274    let index = random_usize_bounded(range.len());
275    &range[index]
276}
277
278pub fn shuffle<T>(source: &mut [T]) {
279    // Fisher–Yates / Knuth shuffle (uniform)
280    for i in 1..source.len() {
281        let j = random_usize_bounded(i + 1);
282        source.swap(i, j);
283    }
284}
285
286pub struct CustomTimeFormatter {
287    time_provider: Arc<dyn TimeProvider>,
288}
289
290impl CustomTimeFormatter {
291    pub fn new(time_provider: Arc<dyn TimeProvider>) -> Self {
292        Self { time_provider }
293    }
294}
295
296impl FormatTime for CustomTimeFormatter {
297    fn format_time(&self, w: &mut tracing_subscriber::fmt::format::Writer<'_>) -> fmt::Result {
298        write!(w, "{}", self.time_provider.current_time_str())
299    }
300}
301
302pub fn configure_logging() {
303    configure_logging_with_time_provider("trace", Arc::new(RealTimeProvider))
304}
305
306pub fn configure_logging_with_time_provider(level: &str, time_provider: Arc<dyn TimeProvider>) {
307    // The filter
308    let filter = format!("{},hyper=off,warp=off,reqwest=off,rustls=off,h2=off,h2=off,html5ever=off,selectors=off,fjall=off,lsm_tree=off,sfa=off,hickory_resolver=off,hickory_proto=off", level);
309    let env_filter = tracing_subscriber::EnvFilter::new(&filter);
310
311    // Prepare the Standard Logging Layer
312    let fmt_layer = tracing_subscriber::fmt::layer().with_timer(CustomTimeFormatter::new(time_provider));
313
314    let registry = tracing_subscriber::registry();
315
316    // Prepare the Console Layer (Conditional) - we only enable this if the 'tokio_unstable' cfg is present and we are not WASM
317    #[cfg(all(tokio_unstable, not(target_arch = "wasm32")))]
318    registry.with(console_subscriber::spawn());
319
320    // Register everything
321    registry.with(fmt_layer).with(env_filter).init();
322
323    info!("Logging initialized");
324}
325
326#[cfg(not(target_arch = "wasm32"))]
327pub type TempDirHandle = tempfile::TempDir;
328
329#[cfg(not(target_arch = "wasm32"))]
330pub fn get_temp_dir() -> anyhow::Result<(TempDirHandle, String)> {
331    let mut base = std::env::temp_dir();
332    base.push("hashiverse-temp");
333
334    // Ensure the base directory exists
335    std::fs::create_dir_all(&base)?;
336
337    let temp_dir = tempfile::Builder::new().prefix("hashiverse-").tempdir_in(&base)?;
338    let temp_dir_path = temp_dir.path().to_str().unwrap().to_string();
339    Ok((temp_dir, temp_dir_path))
340}
341
/// Placeholder temp-dir handle on wasm32: the unit type.
#[cfg(target_arch = "wasm32")]
pub type TempDirHandle = ();

/// wasm32 stand-in for `get_temp_dir`: always succeeds with a unit handle and an empty path.
#[cfg(target_arch = "wasm32")]
pub fn get_temp_dir() -> anyhow::Result<(TempDirHandle, String)> {
    let handle: TempDirHandle = ();
    Ok((handle, String::new()))
}
349
350pub fn from_hex_str<T, const T_BYTES: usize>(str: &str, ctor: impl FnOnce([u8; T_BYTES]) -> T) -> anyhow::Result<T> {
351    if str.len() != 2 * T_BYTES {
352        anyhow::bail!("Invalid hex string length: expected {} hex characters ({} bytes), got {} characters.", 2 * T_BYTES, T_BYTES, str.len(),);
353    }
354
355    // Try to decode the hex string
356    let decoded = hex::decode(str)?;
357
358    // Check if the decoded bytes are exactly xxx bytes
359    if decoded.len() != T_BYTES {
360        anyhow::bail!("Invalid hex string length: expected {} bytes, got {} bytes", T_BYTES, decoded.len());
361    }
362
363    // Convert Vec<u8> to [u8; xxx]
364    let mut decoded_bytes = [0u8; T_BYTES];
365    decoded_bytes.copy_from_slice(&decoded);
366
367    Ok(ctor(decoded_bytes))
368}
369
370/// Spawn a background async task, using the appropriate runtime for the current target.
371pub fn spawn_background_task<F>(task: F)
372where
373    F: Future<Output = ()> + Send + 'static,
374{
375    #[cfg(not(target_arch = "wasm32"))]
376    tokio::spawn(task);
377    #[cfg(target_arch = "wasm32")]
378    wasm_bindgen_futures::spawn_local(task);
379}
380