Skip to main content

zerotrie/byte_phf/
mod.rs

1// This file is part of ICU4X. For terms of use, please see the file
2// called LICENSE at the top level of the ICU4X source tree
3// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
4
5#![allow(rustdoc::private_intra_doc_links)] // doc(hidden) module
6
7//! # Byte Perfect Hash Function Internals
8//!
9//! This module contains a perfect hash function (PHF) designed for a fast, compact perfect
10//! hash over 1 to 256 nodes (bytes).
11//!
12//! The PHF uses the following variables:
13//!
14//! 1. A single parameter `p`, which is 0 in about 98% of cases.
15//! 2. A list of `N` parameters `q_t`, one per _bucket_
16//! 3. The `N` keys in an arbitrary order determined by the PHF
17//!
18//! Reading a `key` from the PHF uses the following algorithm:
19//!
20//! 1. Let `t`, the bucket index, be `f1(key, p)`.
21//! 2. Let `i`, the key index, be `f2(key, q_t)`.
22//! 3. If `key == k_i`, return `Some(i)`; else return `None`.
23//!
24//! The functions [`f1`] and [`f2`] are internal to the PHF but should remain stable across
25//! serialization versions of `ZeroTrie`. They are very fast, constant-time operations as long
26//! as `p` <= [`P_FAST_MAX`] and `q` <= [`Q_FAST_MAX`]. In practice, nearly 100% of parameter
27//! values are in the fast range.
28//!
29//! ```
30//! use zerotrie::_internal::PerfectByteHashMap;
31//!
//! let phf_example_bytes = [
//!     // `p` parameter
//!     1,
//!     // `q` parameters, one for each of the N buckets
//!     0, 0, 1, 1,
//!     // Exact keys to be compared with the input
//!     b'e', b'a', b'c', b'g',
//! ];
38//!
39//! let phf = PerfectByteHashMap::from_bytes(&phf_example_bytes);
40//!
41//! // The PHF returns the index of the key or `None` if not found.
42//! assert_eq!(phf.get(b'a'), Some(1));
43//! assert_eq!(phf.get(b'b'), None);
44//! assert_eq!(phf.get(b'c'), Some(2));
45//! assert_eq!(phf.get(b'd'), None);
46//! assert_eq!(phf.get(b'e'), Some(0));
47//! assert_eq!(phf.get(b'f'), None);
48//! assert_eq!(phf.get(b'g'), Some(3));
49//! ```
50
51use crate::helpers::*;
52
53#[cfg(feature = "alloc")]
54mod builder;
55#[cfg(feature = "alloc")]
56mod cached_owned;
57
58#[cfg(feature = "alloc")]
59pub use cached_owned::PerfectByteHashMapCacheOwned;
60
/// The cutoff for the fast version of [`f1`].
///
/// [`f1`] currently applies the same constant-time formula to every nonzero `p`,
/// so this value only matters to the builder code that selects `p`.
#[cfg(feature = "alloc")] // used in the builder code
const P_FAST_MAX: u8 = 95;

/// The cutoff for the fast version of [`f2`].
///
/// For `q` up to and including this value, [`f2`] is a single XOR followed by a
/// modulus. Each unit of `q` above it costs one extra mixing round.
const Q_FAST_MAX: u8 = 95;

/// The maximum allowable value of `p`. This could be raised if found to be necessary.
/// Values exceeding `P_FAST_MAX` could use a different `p` algorithm by modifying [`f1`].
#[cfg(feature = "alloc")] // used in the builder code
const P_REAL_MAX: u8 = P_FAST_MAX;

/// The maximum allowable value of `q`. This could be raised if found to be necessary.
#[cfg(feature = "alloc")] // used in the builder code
const Q_REAL_MAX: u8 = 127;
76
/// Calculates the function `f1` for the PHF. For the exact formula, please read the code.
///
/// When `p == 0`, the operation is a simple modulus.
///
/// The argument `n` is used only for taking the modulus so that the return value is
/// in the range `[0, n)`. As a special case, `n == 0` returns `byte` unchanged,
/// regardless of `p`.
///
/// # Examples
///
/// ```
/// use zerotrie::_internal::f1;
/// const N: u8 = 10;
///
/// // With p = 0:
/// assert_eq!(0, f1(0, 0, N));
/// assert_eq!(1, f1(1, 0, N));
/// assert_eq!(2, f1(2, 0, N));
/// assert_eq!(9, f1(9, 0, N));
/// assert_eq!(0, f1(10, 0, N));
/// assert_eq!(1, f1(11, 0, N));
/// assert_eq!(2, f1(12, 0, N));
/// assert_eq!(9, f1(19, 0, N));
///
/// // With p = 1:
/// assert_eq!(1, f1(0, 1, N));
/// assert_eq!(0, f1(1, 1, N));
/// assert_eq!(2, f1(2, 1, N));
/// assert_eq!(2, f1(9, 1, N));
/// assert_eq!(4, f1(10, 1, N));
/// assert_eq!(5, f1(11, 1, N));
/// assert_eq!(1, f1(12, 1, N));
/// assert_eq!(7, f1(19, 1, N));
/// ```
#[inline]
pub fn f1(byte: u8, p: u8, n: u8) -> u8 {
    // There is nothing to reduce by, so hand the byte back untouched.
    if n == 0 {
        return byte;
    }
    let mixed = if p == 0 {
        byte
    } else {
        // `p` always uses the below constant-time operation. If needed, we
        // could add some other operation here with `p > P_FAST_MAX` to solve
        // difficult cases if the need arises.
        //
        // `wrapping_shr` masks the shift amount to the bit width of `u8`, so
        // only the low three bits of `p` select the shift distance.
        byte ^ p ^ byte.wrapping_shr(u32::from(p))
    };
    mixed % n
}
124
125/// Calculates the function `f2` for the PHF. For the exact formula, please read the code.
126///
127/// When `q == 0`, the operation is a simple modulus.
128///
129/// The argument `n` is used only for taking the modulus so that the return value is
130/// in the range `[0, n)`.
131///
132/// # Examples
133///
134/// ```
135/// use zerotrie::_internal::f2;
136/// const N: u8 = 10;
137///
138/// // With q = 0:
139/// assert_eq!(0, f2(0, 0, N));
140/// assert_eq!(1, f2(1, 0, N));
141/// assert_eq!(2, f2(2, 0, N));
142/// assert_eq!(9, f2(9, 0, N));
143/// assert_eq!(0, f2(10, 0, N));
144/// assert_eq!(1, f2(11, 0, N));
145/// assert_eq!(2, f2(12, 0, N));
146/// assert_eq!(9, f2(19, 0, N));
147///
148/// // With q = 1:
149/// assert_eq!(1, f2(0, 1, N));
150/// assert_eq!(0, f2(1, 1, N));
151/// assert_eq!(3, f2(2, 1, N));
152/// assert_eq!(8, f2(9, 1, N));
153/// assert_eq!(1, f2(10, 1, N));
154/// assert_eq!(0, f2(11, 1, N));
155/// assert_eq!(3, f2(12, 1, N));
156/// assert_eq!(8, f2(19, 1, N));
157/// ```
158#[inline]
159pub fn f2(byte: u8, q: u8, n: u8) -> u8 {
160    if n == 0 {
161        return byte;
162    }
163    let mut result = byte ^ q;
164    // In almost all cases, the PHF works with the above constant-time operation.
165    // However, to crack a few difficult cases, we fall back to the linear-time
166    // operation shown below.
167    for _ in Q_FAST_MAX..q {
168        result = result ^ (result << 1) ^ (result >> 1);
169    }
170    result % n
171}
172
/// A constant-time map from bytes to unique indices.
///
/// Uses a perfect hash function (see module-level documentation). Does not support mutation.
///
/// Standard layout: P, N bytes of Q, N bytes of expected keys
///
/// The type is a `repr(transparent)` wrapper around its `Store`, which lets a
/// `&[u8]` be reinterpreted as a `&PerfectByteHashMap<[u8]>` without copying.
#[derive(Debug, PartialEq, Eq)]
#[repr(transparent)]
pub struct PerfectByteHashMap<Store: ?Sized>(Store);
181
182impl<Store> PerfectByteHashMap<Store> {
183    /// Creates an instance from a pre-existing store. See [`Self::as_bytes`].
184    #[inline]
185    pub fn from_store(store: Store) -> Self {
186        Self(store)
187    }
188}
189
190impl<Store> PerfectByteHashMap<Store>
191where
192    Store: AsRef<[u8]> + ?Sized,
193{
194    /// Gets the usize for the given byte, or `None` if it is not in the map.
195    pub fn get(&self, key: u8) -> Option<usize> {
196        let (p, buffer) = self.0.as_ref().split_first()?;
197        // Note: there are N buckets followed by N keys
198        let n_usize = buffer.len() / 2;
199        if n_usize == 0 {
200            return None;
201        }
202        let n = n_usize as u8;
203        let (qq, eks) = buffer.debug_split_at(n_usize);
204        if true {
    match (&qq.len(), &eks.len()) {
        (left_val, right_val) => {
            if !(*left_val == *right_val) {
                let kind = ::core::panicking::AssertKind::Eq;
                ::core::panicking::assert_failed(kind, &*left_val,
                    &*right_val, ::core::option::Option::None);
            }
        }
    };
};debug_assert_eq!(qq.len(), eks.len());
205        let l1 = f1(key, *p, n) as usize;
206        let q = match qq.get(l1) {
    Some(x) => x,
    None => {
        if true {
            if !false {
                {
                    ::core::panicking::panic_fmt(format_args!("invalid trie"));
                }
            };
        };
        return None;
    }
}debug_unwrap!(qq.get(l1), return None);
207        let l2 = f2(key, *q, n) as usize;
208        let ek = match eks.get(l2) {
    Some(x) => x,
    None => {
        if true {
            if !false {
                {
                    ::core::panicking::panic_fmt(format_args!("invalid trie"));
                }
            };
        };
        return None;
    }
}debug_unwrap!(eks.get(l2), return None);
209        if *ek == key {
210            Some(l2)
211        } else {
212            None
213        }
214    }
215    /// This is called `num_items` because `len` is ambiguous: it could refer
216    /// to the number of items or the number of bytes.
217    pub fn num_items(&self) -> usize {
218        self.0.as_ref().len() / 2
219    }
220    /// Get an iterator over the keys in the order in which they are stored in the map.
221    pub fn keys(&self) -> &[u8] {
222        let n = self.num_items();
223        self.0.as_ref().debug_split_at(1 + n).1
224    }
225    /// Diagnostic function that returns `p` and the maximum value of `q`
226    #[cfg(test)]
227    pub fn p_qmax(&self) -> Option<(u8, u8)> {
228        let (p, buffer) = self.0.as_ref().split_first()?;
229        let n = buffer.len() / 2;
230        if n == 0 {
231            return None;
232        }
233        let (qq, _) = buffer.debug_split_at(n);
234        Some((*p, *qq.iter().max().unwrap()))
235    }
236    /// Returns the map as bytes. The map can be recovered with [`Self::from_store`]
237    /// or [`Self::from_bytes`].
238    pub fn as_bytes(&self) -> &[u8] {
239        self.0.as_ref()
240    }
241
242    #[cfg(all(feature = "alloc", test))]
243    pub(crate) fn check(&self) -> Result<(), (&'static str, u8)> {
244        use alloc::vec;
245        let len = self.num_items();
246        let mut seen = vec![false; len];
247        for b in 0..=255u8 {
248            let get_result = self.get(b);
249            if self.keys().contains(&b) {
250                let i = get_result.ok_or(("expected to find", b))?;
251                if seen[i] {
252                    return Err(("seen", b));
253                }
254                seen[i] = true;
255            } else if get_result.is_some() {
256                return Err(("did not expect to find", b));
257            }
258        }
259        Ok(())
260    }
261}
262
263impl PerfectByteHashMap<[u8]> {
264    /// Creates an instance from pre-existing bytes. See [`Self::as_bytes`].
265    #[inline]
266    #[allow(unsafe_code)] // transparent newtype casts are documented
267    pub fn from_bytes(bytes: &[u8]) -> &Self {
268        // Safety: Self is repr(transparent) over [u8]
269        unsafe { &*(bytes as *const [u8] as *const Self) }
270    }
271}
272
273impl<Store> PerfectByteHashMap<Store>
274where
275    Store: AsRef<[u8]> + ?Sized,
276{
277    /// Converts from `PerfectByteHashMap<AsRef<[u8]>>` to `&PerfectByteHashMap<[u8]>`
278    #[inline]
279    pub fn as_borrowed(&self) -> &PerfectByteHashMap<[u8]> {
280        PerfectByteHashMap::from_bytes(self.0.as_ref())
281    }
282}
283
#[cfg(all(test, feature = "alloc"))]
mod tests {
    use super::*;
    use alloc::vec::Vec;
    extern crate std;

    /// Returns `len` bytes drawn from the ASCII alphanumerics by a partial
    /// shuffle seeded with `seed`, so the same arguments give the same keys.
    fn random_alphanums(seed: u64, len: usize) -> Vec<u8> {
        use rand::seq::SliceRandom;
        use rand::SeedableRng;

        let mut bytes: Vec<u8> =
            b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789".into();
        let mut rng = rand_pcg::Lcg64Xsh32::seed_from_u64(seed);
        bytes.partial_shuffle(&mut rng, len).0.into()
    }

    /// Builds and checks a map for every `(len, seed)` pair, then asserts that
    /// the largest `q` stays within `Q_FAST_MAX` for about 99% of the maps.
    ///
    /// `label` only distinguishes the diagnostic output of the callers.
    fn check_random_key_sets(
        label: &str,
        lens: core::ops::Range<usize>,
        seeds: core::ops::Range<u64>,
    ) {
        let mut count_by_p = [0usize; 256];
        let mut count_by_qmax = [0usize; 256];
        for len in lens {
            for seed in seeds.clone() {
                let keys = random_alphanums(seed, len);
                let keys_str = core::str::from_utf8(&keys).unwrap();
                let computed = PerfectByteHashMap::try_new(&keys).expect(keys_str);
                computed.check().unwrap_or_else(|_| panic!("{keys_str}"));
                let (p, qmax) = computed.p_qmax().unwrap();
                count_by_p[p as usize] += 1;
                count_by_qmax[qmax as usize] += 1;
            }
        }
        std::println!("count_by_p ({label}): {count_by_p:?}");
        std::println!("count_by_qmax ({label}): {count_by_qmax:?}");
        // Indices `0..=Q_FAST_MAX` are the fast range; everything above is slow.
        let (fast, slow) = count_by_qmax.split_at(Q_FAST_MAX as usize + 1);
        let count_fastq = fast.iter().sum::<usize>();
        let count_slowq = slow.iter().sum::<usize>();
        std::println!("fastq/slowq: {count_fastq}/{count_slowq}");
        // Assert that 99% of cases resolve to the fast hash
        assert!(count_fastq >= count_slowq * 100);
    }

    #[test]
    fn test_smaller() {
        check_random_key_sets("smaller", 1..16, 0..150);
    }

    #[test]
    fn test_larger() {
        check_random_key_sets("larger", 16..60, 0..75);
    }

    #[test]
    fn test_hard_cases() {
        // Every byte from 0 through 126, plus 195 and 196.
        let keys: Vec<u8> = (0u8..=126).chain([195u8, 196]).collect();

        let computed = PerfectByteHashMap::try_new(&keys).unwrap();
        let (p, qmax) = computed.p_qmax().unwrap();
        assert_eq!(p, 69);
        assert_eq!(qmax, 67);
    }

    #[test]
    fn test_build_read_small() {
        #[derive(Debug)]
        struct TestCase<'a> {
            keys: &'a str,
            expected: &'a [u8],
            reordered_keys: &'a str,
        }
        let cases = [
            TestCase {
                keys: "ab",
                expected: &[0, 0, 0, b'b', b'a'],
                reordered_keys: "ba",
            },
            TestCase {
                keys: "abc",
                expected: &[0, 0, 0, 0, b'c', b'a', b'b'],
                reordered_keys: "cab",
            },
            TestCase {
                // Note: splitting "a" and "c" into different buckets requires the heavier hash
                // function because the difference between "a" and "c" is the period (2).
                keys: "ac",
                expected: &[1, 0, 1, b'c', b'a'],
                reordered_keys: "ca",
            },
            TestCase {
                keys: "aceg",
                expected: &[1, 0, 0, 1, 1, b'e', b'a', b'c', b'g'],
                reordered_keys: "eacg",
            },
            TestCase {
                keys: "abd",
                expected: &[0, 0, 1, 3, b'a', b'b', b'd'],
                reordered_keys: "abd",
            },
            TestCase {
                keys: "def",
                expected: &[0, 0, 0, 0, b'f', b'd', b'e'],
                reordered_keys: "fde",
            },
            TestCase {
                keys: "fi",
                expected: &[0, 0, 0, b'f', b'i'],
                reordered_keys: "fi",
            },
            TestCase {
                keys: "gh",
                expected: &[0, 0, 0, b'h', b'g'],
                reordered_keys: "hg",
            },
            TestCase {
                keys: "lm",
                expected: &[0, 0, 0, b'l', b'm'],
                reordered_keys: "lm",
            },
            TestCase {
                // Note: "a" and "q" (0x61 and 0x71) are very hard to split; only a handful of
                // hash function crates can get them into separate buckets.
                keys: "aq",
                expected: &[4, 0, 1, b'a', b'q'],
                reordered_keys: "aq",
            },
            TestCase {
                keys: "xy",
                expected: &[0, 0, 0, b'x', b'y'],
                reordered_keys: "xy",
            },
            TestCase {
                keys: "xyz",
                expected: &[0, 0, 0, 0, b'x', b'y', b'z'],
                reordered_keys: "xyz",
            },
            TestCase {
                keys: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
                expected: &[
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 10, 12, 16, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 16,
                    16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
                    2, 0, 7, 104, 105, 106, 107, 108, 109, 110, 111, 112, 117, 118, 119, 68, 69,
                    70, 113, 114, 65, 66, 67, 120, 121, 122, 115, 72, 73, 74, 71, 80, 81, 82, 83,
                    84, 85, 86, 87, 88, 89, 90, 75, 76, 77, 78, 79, 103, 97, 98, 99, 116, 100, 102,
                    101,
                ],
                reordered_keys: "hijklmnopuvwDEFqrABCxyzsHIJGPQRSTUVWXYZKLMNOgabctdfe",
            },
            TestCase {
                keys: "abcdefghij",
                expected: &[
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 100, 101, 102, 103, 104, 105, 106, 97, 98, 99,
                ],
                reordered_keys: "defghijabc",
            },
            TestCase {
                // This is a small case that resolves to the slow hasher
                keys: "Jbej",
                expected: &[2, 0, 0, 102, 0, b'j', b'e', b'b', b'J'],
                reordered_keys: "jebJ",
            },
            TestCase {
                // This is another small case that resolves to the slow hasher
                keys: "JFNv",
                expected: &[1, 98, 0, 2, 0, b'J', b'F', b'N', b'v'],
                reordered_keys: "JFNv",
            },
        ];
        for case in cases {
            let computed = PerfectByteHashMap::try_new(case.keys.as_bytes()).expect(case.keys);
            assert_eq!(computed.as_bytes(), case.expected, "{case:?}");
            assert_eq!(computed.keys(), case.reordered_keys.as_bytes(), "{case:?}");
            computed.check().expect(case.keys);
        }
    }
}