
Commit

Merge pull request #37 from raphlinus/master
Use minimal perfect hashing for lookups
Manishearth authored Apr 16, 2019
2 parents f24cb8a + 40f9ba6 commit 7c23cc9
Showing 8 changed files with 21,617 additions and 10,783 deletions.
163 changes: 100 additions & 63 deletions scripts/unicode.py
@@ -18,7 +18,7 @@
# Since this should not require frequent updates, we just store this
# out-of-line and check the unicode.rs file into git.
import collections
import requests
import urllib.request

UNICODE_VERSION = "9.0.0"
UCD_URL = "https://www.unicode.org/Public/%s/ucd/" % UNICODE_VERSION
@@ -68,9 +68,9 @@ def __init__(self):

def stats(name, table):
count = sum(len(v) for v in table.values())
print "%s: %d chars => %d decomposed chars" % (name, len(table), count)
print("%s: %d chars => %d decomposed chars" % (name, len(table), count))

print "Decomposition table stats:"
print("Decomposition table stats:")
stats("Canonical decomp", self.canon_decomp)
stats("Compatible decomp", self.compat_decomp)
stats("Canonical fully decomp", self.canon_fully_decomp)
@@ -79,8 +79,8 @@ def stats(name, table):
self.ss_leading, self.ss_trailing = self._compute_stream_safe_tables()

def _fetch(self, filename):
resp = requests.get(UCD_URL + filename)
return resp.text
resp = urllib.request.urlopen(UCD_URL + filename)
return resp.read().decode('utf-8')

def _load_unicode_data(self):
self.combining_classes = {}
@@ -234,7 +234,7 @@ def _decompose(char_int, compatible):
# need to store their overlap when they agree. When they don't agree,
# store the decomposition in the compatibility table since we'll check
# that first when normalizing to NFKD.
assert canon_fully_decomp <= compat_fully_decomp
assert set(canon_fully_decomp) <= set(compat_fully_decomp)

for ch in set(canon_fully_decomp) & set(compat_fully_decomp):
if canon_fully_decomp[ch] == compat_fully_decomp[ch]:
@@ -284,27 +284,37 @@ def _compute_stream_safe_tables(self):

return leading_nonstarters, trailing_nonstarters

hexify = lambda c: hex(c)[2:].upper().rjust(4, '0')
hexify = lambda c: '{:04X}'.format(c)

def gen_combining_class(combining_classes, out):
out.write("#[inline]\n")
out.write("pub fn canonical_combining_class(c: char) -> u8 {\n")
out.write(" match c {\n")

for char, combining_class in sorted(combining_classes.items()):
out.write(" '\u{%s}' => %s,\n" % (hexify(char), combining_class))
def gen_mph_data(name, d, kv_type, kv_callback):
(salt, keys) = minimal_perfect_hash(d)
out.write("pub(crate) const %s_SALT: &[u16] = &[\n" % name.upper())
for s in salt:
out.write(" 0x{:x},\n".format(s))
out.write("];\n")
out.write("pub(crate) const {}_KV: &[{}] = &[\n".format(name.upper(), kv_type))
for k in keys:
out.write(" {},\n".format(kv_callback(k)))
out.write("];\n\n")

out.write(" _ => 0,\n")
out.write(" }\n")
out.write("}\n")
def gen_combining_class(combining_classes, out):
gen_mph_data('canonical_combining_class', combining_classes, 'u32',
lambda k: "0x{:X}".format(int(combining_classes[k]) | (k << 8)))
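
For context, gen_mph_data emits a pair of arrays per table: a salt array and a key/value array. A hypothetical excerpt of the shape it writes into tables.rs (illustrative salt values only; the real arrays are generated from the UCD data):

pub(crate) const CANONICAL_COMBINING_CLASS_SALT: &[u16] = &[
    0x1,
    0x2d,
    // ...
];
pub(crate) const CANONICAL_COMBINING_CLASS_KV: &[u32] = &[
    // Each entry packs key (codepoint) << 8 | value (combining class),
    // e.g. U+0300 with class 230 (0xE6) becomes 0x300E6.
    0x300E6,
    // ...
];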

def gen_composition_table(canon_comp, out):
out.write("#[inline]\n")
out.write("pub fn composition_table(c1: char, c2: char) -> Option<char> {\n")
table = {}
for (c1, c2), c3 in canon_comp.items():
if c1 < 0x10000 and c2 < 0x10000:
table[(c1 << 16) | c2] = c3
(salt, keys) = minimal_perfect_hash(table)
gen_mph_data('COMPOSITION_TABLE', table, '(u32, char)',
lambda k: "(0x%s, '\\u{%s}')" % (hexify(k), hexify(table[k])))

out.write("pub(crate) fn composition_table_astral(c1: char, c2: char) -> Option<char> {\n")
out.write(" match (c1, c2) {\n")

for (c1, c2), c3 in sorted(canon_comp.items()):
out.write(" ('\u{%s}', '\u{%s}') => Some('\u{%s}'),\n" % (hexify(c1), hexify(c2), hexify(c3)))
if c1 >= 0x10000 and c2 >= 0x10000:
out.write(" ('\\u{%s}', '\\u{%s}') => Some('\\u{%s}'),\n" % (hexify(c1), hexify(c2), hexify(c3)))

out.write(" _ => None,\n")
out.write(" }\n")
@@ -313,23 +313,9 @@ def gen_composition_table(canon_comp, out):
def gen_decomposition_tables(canon_decomp, compat_decomp, out):
tables = [(canon_decomp, 'canonical'), (compat_decomp, 'compatibility')]
for table, name in tables:
out.write("#[inline]\n")
out.write("pub fn %s_fully_decomposed(c: char) -> Option<&'static [char]> {\n" % name)
# The "Some" constructor is around the match statement here, because
# putting it into the individual arms would make the item_bodies
# checking of rustc take almost twice as long, and it's already pretty
# slow because of the huge number of match arms and the fact that there
# is a borrow inside each arm
out.write(" Some(match c {\n")

for char, chars in sorted(table.items()):
d = ", ".join("'\u{%s}'" % hexify(c) for c in chars)
out.write(" '\u{%s}' => &[%s],\n" % (hexify(char), d))

out.write(" _ => return None,\n")
out.write(" })\n")
out.write("}\n")
out.write("\n")
gen_mph_data(name + '_decomposed', table, "(u32, &'static [char])",
lambda k: "(0x{:x}, &[{}])".format(k,
", ".join("'\\u{%s}'" % hexify(c) for c in table[k])))

def gen_qc_match(prop_table, out):
out.write(" match c {\n")
@@ -371,40 +367,25 @@ def gen_nfkd_qc(prop_tables, out):
out.write("}\n")

def gen_combining_mark(general_category_mark, out):
out.write("#[inline]\n")
out.write("pub fn is_combining_mark(c: char) -> bool {\n")
out.write(" match c {\n")

for char in general_category_mark:
out.write(" '\u{%s}' => true,\n" % hexify(char))

out.write(" _ => false,\n")
out.write(" }\n")
out.write("}\n")
gen_mph_data('combining_mark', general_category_mark, 'u32',
lambda k: '0x{:04x}'.format(k))

def gen_stream_safe(leading, trailing, out):
# This could be done as a hash but the table is very small.
out.write("#[inline]\n")
out.write("pub fn stream_safe_leading_nonstarters(c: char) -> usize {\n")
out.write(" match c {\n")

for char, num_leading in leading.items():
out.write(" '\u{%s}' => %d,\n" % (hexify(char), num_leading))
for char, num_leading in sorted(leading.items()):
out.write(" '\\u{%s}' => %d,\n" % (hexify(char), num_leading))

out.write(" _ => 0,\n")
out.write(" }\n")
out.write("}\n")
out.write("\n")

out.write("#[inline]\n")
out.write("pub fn stream_safe_trailing_nonstarters(c: char) -> usize {\n")
out.write(" match c {\n")

for char, num_trailing in trailing.items():
out.write(" '\u{%s}' => %d,\n" % (hexify(char), num_trailing))

out.write(" _ => 0,\n")
out.write(" }\n")
out.write("}\n")
gen_mph_data('trailing_nonstarters', trailing, 'u32',
lambda k: "0x{:X}".format(int(trailing[k]) | (k << 8)))

def gen_tests(tests, out):
out.write("""#[derive(Debug)]
@@ -419,7 +400,7 @@ def gen_tests(tests, out):
""")

out.write("pub const NORMALIZATION_TESTS: &[NormalizationTest] = &[\n")
str_literal = lambda s: '"%s"' % "".join("\u{%s}" % c for c in s)
str_literal = lambda s: '"%s"' % "".join("\\u{%s}" % c for c in s)

for test in tests:
out.write(" NormalizationTest {\n")
@@ -432,9 +413,65 @@

out.write("];\n")

# Guaranteed to be less than n.
def my_hash(x, salt, n):
# This hash is based on the theory that multiplication is efficient
mask_32 = 0xffffffff
y = ((x + salt) * 2654435769) & mask_32
y ^= (x * 0x31415926) & mask_32
return (y * n) >> 32

# Compute a minimal perfect hash function; d can be either a dict or a list of keys.
def minimal_perfect_hash(d):
n = len(d)
buckets = dict((h, []) for h in range(n))
for key in d:
h = my_hash(key, 0, n)
buckets[h].append(key)
bsorted = [(len(buckets[h]), h) for h in range(n)]
bsorted.sort(reverse = True)
claimed = [False] * n
salts = [0] * n
keys = [0] * n
for (bucket_size, h) in bsorted:
# Note: the traditional perfect hashing approach would also special-case
# bucket_size == 1 here and assign any empty slot, rather than iterating
# until rehash finds an empty slot. But we're not doing that so we can
# avoid the branch.
if bucket_size == 0:
break
else:
for salt in range(1, 32768):
rehashes = [my_hash(key, salt, n) for key in buckets[h]]
# Make sure there are no rehash collisions within this bucket.
if all(not claimed[hash] for hash in rehashes):
if len(set(rehashes)) < bucket_size:
continue
salts[h] = salt
for key in buckets[h]:
rehash = my_hash(key, salt, n)
claimed[rehash] = True
keys[rehash] = key
break
if salts[h] == 0:
print("minimal perfect hashing failed")
# Note: if this happens (because of unfortunate data), then there are
# a few things that could be done. First, the hash function could be
# tweaked. Second, the bucket order could be scrambled (especially the
# singletons). Right now, the buckets are sorted, which has the advantage
# of being deterministic.
#
# As a more extreme approach, the singleton bucket optimization could be
# applied (give the direct address for singleton buckets, rather than
# relying on a rehash). That is definitely the more standard approach in
# the minimal perfect hashing literature, but in testing the branch was a
# significant slowdown.
exit(1)
return (salts, keys)
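
As a quick illustration of how the two emitted arrays are meant to be consumed (a sketch with made-up data, not part of the commit): hash once with salt 0 to select the bucket's salt, hash again with that salt to select the slot, then verify the stored key.

demo = {0x300: 230, 0x301: 230, 0x41: 0}
(salts, keys) = minimal_perfect_hash(demo)

def demo_lookup(x):
    n = len(keys)
    s = salts[my_hash(x, 0, n)]   # level 1: pick the bucket's salt
    k = keys[my_hash(x, s, n)]    # level 2: pick the candidate slot
    return demo[k] if k == x else None  # a mismatched key means "absent"

assert demo_lookup(0x300) == 230  # combining grave accent -> class 230
assert demo_lookup(0x42) is None  # not in the demo table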

if __name__ == '__main__':
data = UnicodeData()
with open("tables.rs", "w") as out:
with open("tables.rs", "w", newline = "\n") as out:
out.write(PREAMBLE)
out.write("use quick_check::IsNormalized;\n")
out.write("use quick_check::IsNormalized::*;\n")
@@ -470,6 +507,6 @@ def gen_tests(tests, out):
gen_stream_safe(data.ss_leading, data.ss_trailing, out)
out.write("\n")

with open("normalization_tests.rs", "w") as out:
with open("normalization_tests.rs", "w", newline = "\n") as out:
out.write(PREAMBLE)
gen_tests(data.norm_tests, out)
8 changes: 3 additions & 5 deletions src/lib.rs
@@ -65,7 +65,9 @@ pub use stream_safe::StreamSafe;
use std::str::Chars;

mod decompose;
mod lookups;
mod normalize;
mod perfect_hash;
mod recompose;
mod quick_check;
mod stream_safe;
@@ -80,11 +82,7 @@ mod normalization_tests;
pub mod char {
pub use normalize::{decompose_canonical, decompose_compatible, compose};

/// Look up the canonical combining class of a character.
pub use tables::canonical_combining_class;

/// Return whether the given character is a combining mark (`General_Category=Mark`)
pub use tables::is_combining_mark;
pub use lookups::{canonical_combining_class, is_combining_mark};
}
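
The re-exports keep the crate's public API unchanged; for instance, callers can still write the following (a usage sketch; U+0301 COMBINING ACUTE ACCENT has canonical combining class 230 and General_Category=Mn, so both lookups hit the new MPH tables):

use unicode_normalization::char::{canonical_combining_class, is_combining_mark};

assert_eq!(canonical_combining_class('\u{0301}'), 230);
assert!(is_combining_mark('\u{0301}'));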


89 changes: 89 additions & 0 deletions src/lookups.rs
@@ -0,0 +1,89 @@
// Copyright 2019 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Lookups of unicode properties using minimal perfect hashing.
use perfect_hash::mph_lookup;
use tables::*;

/// Look up the canonical combining class for a codepoint.
///
/// The value returned is as defined in the Unicode Character Database.
pub fn canonical_combining_class(c: char) -> u8 {
mph_lookup(c.into(), CANONICAL_COMBINING_CLASS_SALT, CANONICAL_COMBINING_CLASS_KV,
u8_lookup_fk, u8_lookup_fv, 0)
}

pub(crate) fn composition_table(c1: char, c2: char) -> Option<char> {
if c1 < '\u{10000}' && c2 < '\u{10000}' {
mph_lookup((c1 as u32) << 16 | (c2 as u32),
COMPOSITION_TABLE_SALT, COMPOSITION_TABLE_KV,
pair_lookup_fk, pair_lookup_fv_opt, None)
} else {
composition_table_astral(c1, c2)
}
}

pub(crate) fn canonical_fully_decomposed(c: char) -> Option<&'static [char]> {
mph_lookup(c.into(), CANONICAL_DECOMPOSED_SALT, CANONICAL_DECOMPOSED_KV,
pair_lookup_fk, pair_lookup_fv_opt, None)
}

pub(crate) fn compatibility_fully_decomposed(c: char) -> Option<&'static [char]> {
mph_lookup(c.into(), COMPATIBILITY_DECOMPOSED_SALT, COMPATIBILITY_DECOMPOSED_KV,
pair_lookup_fk, pair_lookup_fv_opt, None)
}

/// Return whether the given character is a combining mark (`General_Category=Mark`)
pub fn is_combining_mark(c: char) -> bool {
mph_lookup(c.into(), COMBINING_MARK_SALT, COMBINING_MARK_KV,
bool_lookup_fk, bool_lookup_fv, false)
}

pub fn stream_safe_trailing_nonstarters(c: char) -> usize {
mph_lookup(c.into(), TRAILING_NONSTARTERS_SALT, TRAILING_NONSTARTERS_KV,
u8_lookup_fk, u8_lookup_fv, 0) as usize
}

/// Extract the key in a 24 bit key and 8 bit value packed in a u32.
#[inline]
fn u8_lookup_fk(kv: u32) -> u32 {
kv >> 8
}

/// Extract the value in a 24 bit key and 8 bit value packed in a u32.
#[inline]
fn u8_lookup_fv(kv: u32) -> u8 {
(kv & 0xff) as u8
}

/// Extract the key for a boolean lookup.
#[inline]
fn bool_lookup_fk(kv: u32) -> u32 {
kv
}

/// Extract the value for a boolean lookup.
#[inline]
fn bool_lookup_fv(_kv: u32) -> bool {
true
}

/// Extract the key in a pair.
#[inline]
fn pair_lookup_fk<T>(kv: (u32, T)) -> u32 {
kv.0
}

/// Extract the value in a pair, returning an option.
#[inline]
fn pair_lookup_fv_opt<T>(kv: (u32, T)) -> Option<T> {
Some(kv.1)
}