diff --git a/bench/__init__.py b/bench/__init__.py index 92c668b..3dfb05c 100644 --- a/bench/__init__.py +++ b/bench/__init__.py @@ -34,7 +34,6 @@ def __init__(self, lang_code: str, ids: Set[str]): lt_code, remote_server="http://localhost:8081/" ) self.tool.disabled_rules = { - "MORFOLOGIK_RULE_EN_US", "GERMAN_SPELLER_RULE", "COMMA_PARENTHESIS_WHITESPACE", "DOUBLE_PUNCTUATION", @@ -117,6 +116,7 @@ class NLPRule: def __init__(self, lang_code: str): self.tokenizer = nlprule.Tokenizer(f"storage/{lang_code}_tokenizer.bin") self.rules = nlprule.Rules(f"storage/{lang_code}_rules.bin", self.tokenizer) + self.rules.spell.options.variant = "en_US" def suggest(self, sentence: str) -> Set[Suggestion]: suggestions = { diff --git a/build/README.md b/build/README.md index c826acc..458f7da 100644 --- a/build/README.md +++ b/build/README.md @@ -79,6 +79,7 @@ python build/make_build_dir.py \ --chunker_token_model=$HOME/Downloads/nlprule/en-token.bin \ --chunker_pos_model=$HOME/Downloads/nlprule/en-pos-maxent.bin \ --chunker_chunk_model=$HOME/Downloads/nlprule/en-chunker.bin \ + --spell_map_path=$LT_PATH/org/languagetool/rules/en/contractions.txt \ --out_dir=data/en ``` diff --git a/build/make_build_dir.py b/build/make_build_dir.py index 04fbfae..80f9b7d 100644 --- a/build/make_build_dir.py +++ b/build/make_build_dir.py @@ -6,6 +6,7 @@ from zipfile import ZipFile import lxml.etree as ET import wordfreq +from glob import glob from chardet.universaldetector import UniversalDetector from chunker import write_chunker # type: ignore @@ -59,7 +60,7 @@ def copy_lt_files(out_dir, lt_dir, lang_code): canonicalize(out_dir / xmlfile) -def dump_dictionary(out_path, lt_dir, tag_dict_path, tag_info_path): +def dump_dict(out_path, lt_dir, tag_dict_path, tag_info_path): # dump dictionary, see https://dev.languagetool.org/developing-a-tagger-dictionary os.system( f"java -cp {lt_dir / 'languagetool.jar'} org.languagetool.tools.DictionaryExporter " @@ -83,7 +84,46 @@ def dump_dictionary(out_path, lt_dir, tag_dict_path, tag_info_path): dump_bytes = open(out_path, "rb").read() with open(out_path, "w") as f: - f.write(dump_bytes.decode(result["encoding"])) + f.write(dump_bytes.decode(result["encoding"] or "utf-8")) + + +def proc_spelling_text(in_paths, out_path, lang_code): + with open(out_path, "w") as f: + for in_path in in_paths: + if in_path.exists(): + for line in open(in_path): + # strip comments + comment_index = line.find("#") + if comment_index != -1: + line = line[:comment_index] + + line = line.strip() + if len(line) == 0: + continue + + try: + word, suffix = line.split("/") + + assert lang_code == "de", "Flags are only supported for German!" + + for flag in suffix: + assert flag != "Ä" + if flag == "A" and word.endswith("e"): + flag = "Ä" + + f.write(word + "\n") + + for ending in { + "S": ["s"], + "N": ["n"], + "E": ["e"], + "F": ["in"], + "A": ["e", "er", "es", "en", "em"], + "Ä": ["r", "s", "n", "m"], + }[flag]: + f.write(word + ending + "\n") + except ValueError: + f.write(line + "\n") if __name__ == "__main__": @@ -138,6 +178,12 @@ def dump_dictionary(out_path, lt_dir, tag_dict_path, tag_info_path): default=None, help="Path to the OpenNLP chunker binary. See token model message for details.", ) + parser.add_argument( + "--spell_map_path", + default=None, + action="append", + help="Paths to files containing a mapping from incorrect words to correct ones e.g. 
contractions.txt for English.", + ) parser.add_argument( "--out_dir", type=lambda p: Path(p).absolute(), @@ -149,12 +195,72 @@ def dump_dictionary(out_path, lt_dir, tag_dict_path, tag_info_path): write_freqlist(open(args.out_dir / "common.txt", "w"), args.lang_code) copy_lt_files(args.out_dir, args.lt_dir, args.lang_code) - dump_dictionary( + + # tagger dictionary + dump_dict( args.out_dir / "tags" / "output.dump", args.lt_dir, args.tag_dict_path, args.tag_info_path, ) + + # spell dictionaries + (args.out_dir / "spell").mkdir() + for dic in glob( + str( + args.lt_dir + / "org" + / "languagetool" + / "resource" + / args.lang_code + / "hunspell" + / "*.dict" + ) + ): + dic = Path(dic) + info = Path(dic).with_suffix(".info") + + variant_name = dic.stem + + dump_dict( + args.out_dir / "spell" / f"{variant_name}.dump", args.lt_dir, dic, info, + ) + proc_spelling_text( + [ + ( + dic / ".." / ("spelling_" + variant_name.replace("_", "-") + ".txt") + ).resolve(), + ( + dic / ".." / ("spelling-" + variant_name.replace("_", "-") + ".txt") + ).resolve(), + ], + args.out_dir / "spell" / f"{variant_name}.txt", + args.lang_code, + ) + + proc_spelling_text( + [ + args.lt_dir + / "org" + / "languagetool" + / "resource" + / args.lang_code + / "hunspell" + / "spelling.txt" + ], + args.out_dir / "spell" / "spelling.txt", + args.lang_code, + ) + + with open(args.out_dir / "spell" / "map.txt", "w") as f: + for path in args.spell_map_path or []: + for line in open(path): + if line.startswith("#"): + continue + + assert "#" not in line + f.write(line) + if ( args.chunker_token_model is not None and args.chunker_pos_model is not None diff --git a/build/src/lib.rs b/build/src/lib.rs index fbc18c9..281b3b9 100644 --- a/build/src/lib.rs +++ b/build/src/lib.rs @@ -5,7 +5,7 @@ use flate2::bufread::GzDecoder; use fs::File; use fs_err as fs; use nlprule::{compile, rules_filename, tokenizer_filename}; -use std::fs::Permissions; +use std::{fs::Permissions, sync::Arc}; use std::{ io::{self, BufReader, BufWriter, Cursor, Read}, path::{Path, PathBuf}, @@ -469,10 +469,11 @@ impl BinaryBuilder { let tokenizer_out = self.out_dir.join(tokenizer_filename(lang_code)); let rules_out = self.out_dir.join(rules_filename(lang_code)); - nlprule::Rules::new(rules_out) - .map_err(|e| Error::ValidationFailed(lang_code.to_owned(), Binary::Rules, e))?; - nlprule::Tokenizer::new(tokenizer_out) + let tokenizer = nlprule::Tokenizer::new(tokenizer_out) .map_err(|e| Error::ValidationFailed(lang_code.to_owned(), Binary::Tokenizer, e))?; + + nlprule::Rules::new(rules_out, Arc::new(tokenizer)) + .map_err(|e| Error::ValidationFailed(lang_code.to_owned(), Binary::Rules, e))?; } Ok(()) diff --git a/nlprule/Cargo.toml b/nlprule/Cargo.toml index 0f49004..a0c59d3 100644 --- a/nlprule/Cargo.toml +++ b/nlprule/Cargo.toml @@ -30,6 +30,8 @@ half = { version = "1.7", features = ["serde"] } srx = { version = "^0.1.2", features = ["serde"] } lazycell = "1" cfg-if = "1" +fnv = "1" +unicode_categories = "0.1" rayon-cond = "0.1" rayon = "1.5" diff --git a/nlprule/build.rs b/nlprule/build.rs index 8eb2ad1..543bec2 100644 --- a/nlprule/build.rs +++ b/nlprule/build.rs @@ -20,6 +20,7 @@ fn main() { ("tokenizer.json", "tokenizer_configs.json"), ("rules.json", "rules_configs.json"), ("tagger.json", "tagger_configs.json"), + ("spellchecker.json", "spellchecker_configs.json"), ] { let mut config_map: HashMap = HashMap::new(); diff --git a/nlprule/configs/de/spellchecker.json b/nlprule/configs/de/spellchecker.json new file mode 100644 index 0000000..0500649 --- /dev/null 
+++ b/nlprule/configs/de/spellchecker.json @@ -0,0 +1,8 @@ +{ + "variants": [ + "de_AT", + "de_DE", + "de_CH" + ], + "split_hyphens": true +} \ No newline at end of file diff --git a/nlprule/configs/en/rules.json b/nlprule/configs/en/rules.json index a1a27a0..2fd88b3 100644 --- a/nlprule/configs/en/rules.json +++ b/nlprule/configs/en/rules.json @@ -3,5 +3,6 @@ "ignore_ids": [ "GRAMMAR/PRP_MD_NN/2", "TYPOS/VERB_APOSTROPHE_S/3" - ] + ], + "split_hyphens": true } \ No newline at end of file diff --git a/nlprule/configs/en/spellchecker.json b/nlprule/configs/en/spellchecker.json new file mode 100644 index 0000000..d9774ff --- /dev/null +++ b/nlprule/configs/en/spellchecker.json @@ -0,0 +1,8 @@ +{ + "variants": [ + "en_GB", + "en_US", + "en_AU" + ], + "split_hyphens": true +} \ No newline at end of file diff --git a/nlprule/configs/es/spellchecker.json b/nlprule/configs/es/spellchecker.json new file mode 100644 index 0000000..51a957d --- /dev/null +++ b/nlprule/configs/es/spellchecker.json @@ -0,0 +1,4 @@ +{ + "variants": [], + "split_hyphens": true +} \ No newline at end of file diff --git a/nlprule/src/bin/run.rs b/nlprule/src/bin/run.rs index c6da086..9825769 100644 --- a/nlprule/src/bin/run.rs +++ b/nlprule/src/bin/run.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use clap::Clap; use nlprule::{rules::Rules, tokenizer::Tokenizer}; @@ -18,11 +20,12 @@ fn main() { env_logger::init(); let opts = Opts::parse(); - let tokenizer = Tokenizer::new(opts.tokenizer).unwrap(); - let rules = Rules::new(opts.rules).unwrap(); + let tokenizer = Arc::new(Tokenizer::new(opts.tokenizer).unwrap()); + let mut rules = Rules::new(opts.rules, tokenizer.clone()).unwrap(); + rules.spell_mut().options_mut().variant = Some(rules.spell().variant("en_GB").unwrap()); let tokens = tokenizer.pipe(&opts.text); println!("Tokens: {:#?}", tokens); - println!("Suggestions: {:#?}", rules.suggest(&opts.text, &tokenizer)); + println!("Suggestions: {:#?}", rules.suggest(&opts.text)); } diff --git a/nlprule/src/bin/test.rs b/nlprule/src/bin/test.rs index 3669a8e..2a81ed6 100644 --- a/nlprule/src/bin/test.rs +++ b/nlprule/src/bin/test.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use clap::Clap; use nlprule::{rules::Rules, tokenizer::Tokenizer}; @@ -19,8 +21,8 @@ fn main() { env_logger::init(); let opts = Opts::parse(); - let tokenizer = Tokenizer::new(opts.tokenizer).unwrap(); - let rules_container = Rules::new(opts.rules).unwrap(); + let tokenizer = Arc::new(Tokenizer::new(opts.tokenizer).unwrap()); + let rules_container = Rules::new(opts.rules, tokenizer.clone()).unwrap(); let rules = rules_container.rules(); println!("Runnable rules: {}", rules.len()); diff --git a/nlprule/src/compile/impls.rs b/nlprule/src/compile/impls.rs index c5226a2..c1b41bd 100644 --- a/nlprule/src/compile/impls.rs +++ b/nlprule/src/compile/impls.rs @@ -4,10 +4,12 @@ use indexmap::IndexMap; use log::warn; use serde::{Deserialize, Serialize}; use std::{ + cmp, collections::{HashMap, HashSet}, hash::{Hash, Hasher}, io::{self, BufRead, BufReader}, path::Path, + sync::Arc, }; use crate::{ @@ -17,10 +19,12 @@ use crate::{ composition::{GraphId, Matcher, PosMatcher, TextMatcher}, Engine, }, + grammar::PosReplacer, id::Category, DisambiguationRule, MatchGraph, Rule, }, - rules::{Rules, RulesLangOptions, RulesOptions}, + rules::{Rules, RulesLangOptions}, + spell::{Spell, SpellInt, SpellLangOptions}, tokenizer::{ chunk, multiword::{MultiwordTagger, MultiwordTaggerFields}, @@ -33,6 +37,201 @@ use crate::{ use super::{parse_structure::BuildInfo, Error}; +impl Spell { + fn new( 
+ fst: Vec, + multiwords: DefaultHashMap, SpellInt)>>, + max_freq: usize, + map: DefaultHashMap, + lang_options: SpellLangOptions, + ) -> Self { + let mut spell = Spell { + fst, + multiwords, + max_freq, + map, + lang_options, + ..Default::default() + }; + spell.ingest_options(); + spell + } + + #[allow(clippy::clippy::too_many_arguments)] // lots of arguments here but not easily avoidable + fn add_line( + word: &str, + freq: usize, + words: &mut DefaultHashMap, + multiwords: &mut DefaultHashMap, SpellInt)>>, + variant_index: usize, + // in some LT lists an underline denotes a prefix such that e.g. "hin_reiten" also adds "hingeritten" + underline_denotes_prefix: bool, + build_info: &mut BuildInfo, + tokenizer: &Tokenizer, + ) { + let tokens = tokenizer.get_token_strs(word); + + if tokens.len() > 1 { + assert!(!(underline_denotes_prefix && word.contains('_'))); // not supported in entries spanning multiple tokens + + let mut int = SpellInt::default(); + int.add_variant(variant_index); + + // we do not add the frequency for multiwords - they are not used for suggestions, just to check validity + + multiwords + .entry(tokens[0].to_owned()) + .or_insert_with(Vec::new) + .push(( + tokens[1..] + .iter() + .filter(|x| !x.trim().is_empty()) + .map(|x| (*x).to_owned()) + .collect(), + int, + )); + } else if word.contains('_') && underline_denotes_prefix { + assert!(!word.contains('\\')); // escaped underlines are not supported + let mut parts = word.split('_'); + + let prefix = parts.next().unwrap(); + let suffix = parts.next().unwrap(); + + // this will presumably always be covered by the extra suffixes, but add it just to make sure + let value = words.entry(format!("{}{}", prefix, suffix)).or_default(); + value.add_variant(variant_index); + value.update_freq(freq); + + let replacer = PosReplacer { + matcher: PosMatcher::new( + Matcher::new_regex(Regex::new("^VER:.*".into()), false, true), + build_info, + ), + }; + + for new_suffix in replacer.apply(suffix, tokenizer) { + let new_word = format!("{}{}", prefix, new_suffix); + + let value = words.entry(new_word).or_default(); + value.add_variant(variant_index); + value.update_freq(freq); + } + } else { + let value = words.entry((*word).to_owned()).or_default(); + + value.update_freq(freq); + value.add_variant(variant_index); + } + } + + pub(in crate::compile) fn from_dumps( + spell_dir_path: impl AsRef, + map_path: impl AsRef, + global_word_path: impl AsRef, + build_info: &mut BuildInfo, + lang_options: SpellLangOptions, + tokenizer: &Tokenizer, + ) -> io::Result { + let mut words: DefaultHashMap = DefaultHashMap::new(); + let mut multiwords: DefaultHashMap, SpellInt)>> = + DefaultHashMap::new(); + let mut max_freq = 0; + + for (i, variant) in lang_options.variants.iter().enumerate() { + let spell_path = spell_dir_path + .as_ref() + .join(variant.as_str()) + .with_extension("dump"); + + let reader = BufReader::new(File::open(&spell_path)?); + for line in reader.lines() { + match line? + .trim() + .split_whitespace() + .collect::>() + .as_slice() + { + [freq, word] => { + // frequency is denoted as letters from A to Z in LanguageTool where A is the least frequent. 
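+                        // e.g. 'A' maps to 0 and 'Z' to 25; `add_line` stores this value in the
+                        // 8-bit frequency field of the word's `SpellInt` via `update_freq`.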
+ let freq = freq.chars().next().expect("freq must have one char - would not have been yielded by split_whitespace otherwise.") as usize - 'A' as usize; + + Spell::add_line( + word, + freq, + &mut words, + &mut multiwords, + i, + false, + build_info, + tokenizer, + ); + max_freq = cmp::max(max_freq, freq); + } + _ => continue, + } + } + + let global_word_reader = BufReader::new(File::open(global_word_path.as_ref())?); + + let extra_word_path = spell_path.with_extension("txt"); + let reader = BufReader::new(File::open(&extra_word_path)?); + for line in reader.lines().chain(global_word_reader.lines()) { + let line = line?; + let word = line.trim(); + + Spell::add_line( + word, + 0, + &mut words, + &mut multiwords, + i, + true, + build_info, + tokenizer, + ); + } + } + let mut words: Vec<_> = words + .into_iter() + .map(|(key, value)| (key, value.as_u64())) + .collect(); + words.sort_by(|(a, _), (b, _)| a.cmp(b)); + + let fst = + fst::Map::from_iter(words.into_iter()).expect("words are lexicographically sorted."); + + let mut map = DefaultHashMap::new(); + let reader = BufReader::new(File::open(map_path.as_ref())?); + for line in reader.lines() { + let line = line?; + + let mut parts = line.split('='); + let wrong = parts + .next() + .expect("spell map line must have part before =") + .to_owned(); + let right = parts + .next() + .expect("spell map line must have part after =") + .to_owned(); + + // map lookup happens on token level, so the key has to be exactly one token + assert_eq!(tokenizer.get_token_strs(&wrong).len(), 1); + + map.insert(wrong, right); + assert!(parts.next().is_none()); + } + + Ok(Spell::new( + fst.into_fst().to_vec(), + multiwords, + max_freq, + map, + lang_options, + )) + } +} + impl Tagger { fn get_lines, S2: AsRef>( paths: &[S1], @@ -263,6 +462,8 @@ impl Rules { pub(in crate::compile) fn from_xml>( path: P, build_info: &mut BuildInfo, + spell: Spell, + tokenizer: Arc, options: RulesLangOptions, ) -> Self { let rules = super::parse_structure::read_rules(path); @@ -357,7 +558,8 @@ impl Rules { Rules { rules, - options: RulesOptions::default(), + spell, + tokenizer, } } } diff --git a/nlprule/src/compile/mod.rs b/nlprule/src/compile/mod.rs index 95a8be6..1654a19 100644 --- a/nlprule/src/compile/mod.rs +++ b/nlprule/src/compile/mod.rs @@ -14,8 +14,9 @@ use std::{ use crate::{ rules::Rules, + spell::Spell, tokenizer::{chunk::Chunker, multiword::MultiwordTagger, tag::Tagger, Tokenizer}, - types::DefaultHasher, + types::*, }; use log::info; @@ -35,12 +36,16 @@ struct BuildFilePaths { disambiguation_path: PathBuf, grammar_path: PathBuf, multiword_tag_path: PathBuf, - common_words_path: PathBuf, regex_cache_path: PathBuf, srx_path: PathBuf, + common_words_path: PathBuf, + spell_dir_path: PathBuf, + spell_map_path: PathBuf, + spell_extra_path: PathBuf, } impl BuildFilePaths { + // this has to be kept in sync with the paths the builder in build/make_build_dir.py stores the resources at fn new>(build_dir: P) -> Self { let p = build_dir.as_ref(); BuildFilePaths { @@ -51,9 +56,12 @@ impl BuildFilePaths { disambiguation_path: p.join("disambiguation.xml"), grammar_path: p.join("grammar.xml"), multiword_tag_path: p.join("tags/multiwords.txt"), - common_words_path: p.join("common.txt"), regex_cache_path: p.join("regex_cache.bin"), srx_path: p.join("segment.srx"), + common_words_path: p.join("common.txt"), + spell_dir_path: p.join("spell"), + spell_map_path: p.join("spell/map.txt"), + spell_extra_path: p.join("spell/spelling.txt"), } } } @@ -81,6 +89,8 @@ pub enum Error { 
Unimplemented(String), #[error("error parsing to integer: {0}")] ParseError(#[from] ParseIntError), + #[error("nlprule error: {0}")] + NLPRuleError(#[from] crate::Error), #[error("unknown error")] Other(#[from] Box), } @@ -120,6 +130,13 @@ pub fn compile( lang_code: lang_code.clone(), })?; + let spellchecker_lang_options = + utils::spellchecker_lang_options(&lang_code).ok_or_else(|| { + Error::LanguageOptionsDoNotExist { + lang_code: lang_code.clone(), + } + })?; + info!("Creating tagger."); let tagger = Tagger::from_dumps( &paths.tag_paths, @@ -163,6 +180,18 @@ pub fn compile( } else { None }; + + info!("Creating tokenizer."); + + let mut tokenizer = Tokenizer::from_xml( + &paths.disambiguation_path, + &mut build_info, + chunker, + None, + srx::SRX::from_str(&fs::read_to_string(&paths.srx_path)?)?.language_rules(lang_code), + tokenizer_lang_options, + )?; + let multiword_tagger = if paths.multiword_tag_path.exists() { info!( "{} exists. Building multiword tagger.", @@ -175,22 +204,30 @@ pub fn compile( } else { None }; + tokenizer.multiword_tagger = multiword_tagger; + tokenizer.to_writer(&mut tokenizer_dest)?; - info!("Creating tokenizer."); - let tokenizer = Tokenizer::from_xml( - &paths.disambiguation_path, + info!("Creating spellchecker."); + + let spellchecker = Spell::from_dumps( + paths.spell_dir_path, + paths.spell_map_path, + paths.spell_extra_path, &mut build_info, - chunker, - multiword_tagger, - srx::SRX::from_str(&fs::read_to_string(&paths.srx_path)?)?.language_rules(lang_code), - tokenizer_lang_options, + spellchecker_lang_options, + &tokenizer, )?; - bincode::serialize_into(&mut tokenizer_dest, &tokenizer)?; - info!("Creating grammar rules."); - let rules = Rules::from_xml(&paths.grammar_path, &mut build_info, rules_lang_options); - bincode::serialize_into(&mut rules_dest, &rules)?; + + let rules = Rules::from_xml( + &paths.grammar_path, + &mut build_info, + spellchecker, + Arc::new(tokenizer), + rules_lang_options, + ); + rules.to_writer(&mut rules_dest)?; // we need to write the regex cache after building the rules, otherwise it isn't fully populated let f = BufWriter::new(File::create(&paths.regex_cache_path)?); diff --git a/nlprule/src/compile/parse_structure.rs b/nlprule/src/compile/parse_structure.rs index 5fe417c..44aeb4d 100644 --- a/nlprule/src/compile/parse_structure.rs +++ b/nlprule/src/compile/parse_structure.rs @@ -1054,7 +1054,7 @@ impl DisambiguationRule { }) .collect(), )), - Some("ignore_spelling") => Ok(Disambiguation::Nop), // ignore_spelling can be ignored since we dont check spelling + Some("ignore_spelling") => Ok(Disambiguation::IgnoreSpelling), Some("immunize") => Ok(Disambiguation::Nop), // immunize can probably not be ignored Some("filterall") => { let mut disambig = Vec::new(); diff --git a/nlprule/src/compile/utils.rs b/nlprule/src/compile/utils.rs index 73b5322..e86d30e 100644 --- a/nlprule/src/compile/utils.rs +++ b/nlprule/src/compile/utils.rs @@ -1,4 +1,4 @@ -use crate::{rules::RulesLangOptions, tokenizer::TokenizerLangOptions}; +use crate::{rules::RulesLangOptions, spell::SpellLangOptions, tokenizer::TokenizerLangOptions}; use crate::{tokenizer::tag::TaggerLangOptions, types::*}; use lazy_static::lazy_static; @@ -35,6 +35,17 @@ lazy_static! { }; } +lazy_static! 
{ + static ref SPELLCHECKER_LANG_OPTIONS: DefaultHashMap = { + serde_json::from_slice(include_bytes!(concat!( + env!("OUT_DIR"), + "/", + "spellchecker_configs.json" + ))) + .expect("tagger configs must be valid JSON") + }; +} + /// Gets the tokenizer language options for the language code pub(crate) fn tokenizer_lang_options(lang_code: &str) -> Option { TOKENIZER_LANG_OPTIONS.get(lang_code).cloned() @@ -50,6 +61,11 @@ pub(crate) fn tagger_lang_options(lang_code: &str) -> Option TAGGER_LANG_OPTIONS.get(lang_code).cloned() } +/// Gets the spellchecker language options for the language code +pub(crate) fn spellchecker_lang_options(lang_code: &str) -> Option { + SPELLCHECKER_LANG_OPTIONS.get(lang_code).cloned() +} + pub(crate) use regex::from_java_regex; mod regex { diff --git a/nlprule/src/lib.rs b/nlprule/src/lib.rs index 0a3fd5e..6b2b047 100644 --- a/nlprule/src/lib.rs +++ b/nlprule/src/lib.rs @@ -5,32 +5,35 @@ //! - A [Tokenizer][tokenizer::Tokenizer] to split a text into tokens and analyze it by chunking, lemmatizing and part-of-speech tagging. Can also be used independently of the grammatical rules. //! - A [Rules][rules::Rules] structure containing a set of grammatical error correction rules. //! -//! # Example: correct a text +//! # Examples //! +//! Correct a text: //! ```no_run //! use nlprule::{Tokenizer, Rules}; //! //! let tokenizer = Tokenizer::new("path/to/en_tokenizer.bin")?; -//! let rules = Rules::new("path/to/en_rules.bin")?; +//! let mut rules = Rules::new("path/to/en_rules.bin", tokenizer.into())?; +//! // enable spellchecking +//! rules.spell_mut().options_mut().variant = Some(rules.spell().variant("en_GB")?); //! //! assert_eq!( -//! rules.correct("She was not been here since Monday.", &tokenizer), -//! String::from("She was not here since Monday.") +//! rules.correct("I belive she was not been here since Monday."), +//! String::from("I believe she was not here since Monday.") //! ); //! # Ok::<(), nlprule::Error>(()) //! ``` //! -//! # Example: get suggestions and correct a text +//! Get suggestions and correct a text: //! //! ```no_run //! use nlprule::{Tokenizer, Rules, types::Suggestion, rules::apply_suggestions}; //! //! let tokenizer = Tokenizer::new("path/to/en_tokenizer.bin")?; -//! let rules = Rules::new("path/to/en_rules.bin")?; +//! let rules = Rules::new("path/to/en_rules.bin", tokenizer.into())?; //! //! let text = "She was not been here since Monday."; //! -//! let suggestions = rules.suggest(text, &tokenizer); +//! let suggestions = rules.suggest(text); //! assert_eq!( //! suggestions, //! vec![Suggestion { @@ -48,6 +51,28 @@ //! # Ok::<(), nlprule::Error>(()) //! ``` //! +//! Tokenize & analyze a text: +//! +//! ```no_run +//! use nlprule::Tokenizer; +//! +//! let tokenizer = Tokenizer::new("path/to/en_tokenizer.bin")?; +//! +//! let text = "A brief example is shown."; +//! +//! // returns a vector over sentences +//! // we assume this is one sentence so we take the first element +//! let tokens = tokenizer.pipe(text).remove(0); +//! +//! println!("{:#?}", tokens); +//! // token at index zero is the special SENT_START token - generally not interesting +//! assert_eq!(tokens[2].word.text.as_ref(), "brief"); +//! assert_eq!(tokens[2].word.tags[0].pos.as_ref(), "JJ"); +//! assert_eq!(tokens[2].chunks, vec!["I-NP-singular"]); +//! // some other information like char / byte span, lemmas etc. is also set! +//! # Ok::<(), nlprule::Error>(()) +//! ``` +//! --- //! Binaries are distributed with [Github releases](https://github.com/bminixhofer/nlprule/releases). 
//! //! # The 't lifetime @@ -63,6 +88,7 @@ pub mod compile; mod filter; pub mod rule; pub mod rules; +pub mod spell; pub mod tokenizer; pub mod types; pub(crate) mod utils; @@ -77,6 +103,8 @@ pub enum Error { Io(#[from] io::Error), #[error("deserialization error: {0}")] Deserialization(#[from] bincode::Error), + #[error("unknown language variant: \"{0}\". known variants are: {1:?}.")] + UnknownVariant(String, Vec), } /// Gets the canonical filename for the tokenizer binary for a language code in ISO 639-1 (two-letter) format. diff --git a/nlprule/src/rule/disambiguation.rs b/nlprule/src/rule/disambiguation.rs index 2fc80b4..6f067df 100644 --- a/nlprule/src/rule/disambiguation.rs +++ b/nlprule/src/rule/disambiguation.rs @@ -44,6 +44,7 @@ pub enum Disambiguation { Replace(Vec), Filter(Vec>>), Unify(Vec>, Vec>, Vec), + IgnoreSpelling, Nop, } @@ -190,6 +191,13 @@ impl Disambiguation { } } } + Disambiguation::IgnoreSpelling => { + for group in groups { + for token in group { + token.ignore_spelling = true; + } + } + } Disambiguation::Nop => {} } } diff --git a/nlprule/src/rule/engine/composition.rs b/nlprule/src/rule/engine/composition.rs index d05986e..378ea8a 100644 --- a/nlprule/src/rule/engine/composition.rs +++ b/nlprule/src/rule/engine/composition.rs @@ -4,7 +4,7 @@ use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use unicase::UniCase; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Matcher { pub matcher: either::Either, Regex>, pub negate: bool, @@ -68,7 +68,7 @@ impl Matcher { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) struct TextMatcher { pub(crate) matcher: Matcher, pub(crate) set: Option>, @@ -107,7 +107,7 @@ impl PosMatcher { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct WordDataMatcher { pub(crate) pos_matcher: Option, pub(crate) inflect_matcher: Option, @@ -141,7 +141,7 @@ impl WordDataMatcher { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Quantifier { pub min: usize, pub max: usize, @@ -153,7 +153,7 @@ pub trait Atomable: Send + Sync { } #[enum_dispatch(Atomable)] -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum Atom { ChunkAtom(concrete::ChunkAtom), SpaceBeforeAtom(concrete::SpaceBeforeAtom), @@ -171,7 +171,7 @@ pub mod concrete { use super::{Atomable, MatchGraph, Matcher, TextMatcher, Token, WordDataMatcher}; use serde::{Deserialize, Serialize}; - #[derive(Debug, Serialize, Deserialize)] + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TextAtom { pub(crate) matcher: TextMatcher, } @@ -183,7 +183,7 @@ pub mod concrete { } } - #[derive(Debug, Serialize, Deserialize)] + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChunkAtom { pub(crate) matcher: Matcher, } @@ -195,7 +195,7 @@ pub mod concrete { } } - #[derive(Debug, Serialize, Deserialize)] + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SpaceBeforeAtom { pub(crate) value: bool, } @@ -206,7 +206,7 @@ pub mod concrete { } } - #[derive(Debug, Serialize, Deserialize)] + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WordDataAtom { pub(crate) matcher: WordDataMatcher, pub(crate) case_sensitive: bool, @@ -222,7 +222,7 @@ pub mod concrete { } } -#[derive(Debug, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct TrueAtom {} impl 
Atomable for TrueAtom { @@ -231,7 +231,7 @@ impl Atomable for TrueAtom { } } -#[derive(Debug, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct FalseAtom {} impl Atomable for FalseAtom { @@ -240,7 +240,7 @@ impl Atomable for FalseAtom { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct AndAtom { pub(crate) atoms: Vec, } @@ -253,7 +253,7 @@ impl Atomable for AndAtom { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct OrAtom { pub(crate) atoms: Vec, } @@ -266,7 +266,7 @@ impl Atomable for OrAtom { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct NotAtom { pub(crate) atom: Box, } @@ -277,7 +277,7 @@ impl Atomable for NotAtom { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct OffsetAtom { pub(crate) atom: Box, pub(crate) offset: isize, @@ -449,7 +449,7 @@ impl<'t> MatchGraph<'t> { } } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Part { pub atom: Atom, pub quantifier: Quantifier, @@ -458,7 +458,7 @@ pub struct Part { pub unify: Option, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Composition { pub(crate) parts: Vec, pub(crate) id_to_idx: DefaultHashMap, diff --git a/nlprule/src/rule/engine/mod.rs b/nlprule/src/rule/engine/mod.rs index 75d933f..0667267 100644 --- a/nlprule/src/rule/engine/mod.rs +++ b/nlprule/src/rule/engine/mod.rs @@ -9,7 +9,7 @@ use composition::{Composition, Group, MatchGraph}; use self::composition::GraphId; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct TokenEngine { pub(crate) composition: Composition, pub(crate) antipatterns: Vec, @@ -53,7 +53,7 @@ impl TokenEngine { } } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum Engine { Token(TokenEngine), // regex with the `fancy_regex` backend is large on the stack diff --git a/nlprule/src/rule/grammar.rs b/nlprule/src/rule/grammar.rs index a289ae3..3b10a07 100644 --- a/nlprule/src/rule/grammar.rs +++ b/nlprule/src/rule/grammar.rs @@ -16,7 +16,7 @@ impl std::cmp::PartialEq for Suggestion { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum Conversion { Nop, AllLower, @@ -38,7 +38,7 @@ impl Conversion { } /// An example associated with a [Rule][crate::rule::Rule]. 
-#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Example { pub(crate) text: String, pub(crate) suggestion: Option, @@ -58,13 +58,13 @@ impl Example { } } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct PosReplacer { pub(crate) matcher: PosMatcher, } impl PosReplacer { - fn apply(&self, text: &str, tokenizer: &Tokenizer) -> Option { + pub fn apply(&self, text: &str, tokenizer: &Tokenizer) -> Vec { let mut candidates: Vec<_> = tokenizer .tagger() .get_tags(text) @@ -75,13 +75,13 @@ impl PosReplacer { .get_group_members(&x.lemma.as_ref().to_string()); let mut data = Vec::new(); for word in group_words { - if let Some(i) = tokenizer + if let Some(_i) = tokenizer .tagger() .get_tags(word) .iter() .position(|x| self.matcher.is_match(&x.pos)) { - data.push((word.to_string(), i)); + data.push(word.to_string()); } } data @@ -89,16 +89,13 @@ impl PosReplacer { .rev() .flatten() .collect(); - candidates.sort_by(|(_, a), (_, b)| a.cmp(b)); - if candidates.is_empty() { - None - } else { - Some(candidates.remove(0).0) - } + candidates.sort_unstable(); + candidates.dedup(); + candidates } } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Match { pub(crate) id: GraphId, pub(crate) conversion: Conversion, @@ -111,7 +108,7 @@ impl Match { let text = graph.by_id(self.id).text(graph.tokens()[0].sentence); let mut text = if let Some(replacer) = &self.pos_replacer { - replacer.apply(text, tokenizer)? + replacer.apply(text, tokenizer).into_iter().next()? } else { text.to_string() }; @@ -131,14 +128,14 @@ impl Match { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum SynthesizerPart { Text(String), // Regex with the `fancy_regex` backend is large on the stack Match(Box), } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Synthesizer { pub(crate) use_titlecase_adjust: bool, pub(crate) parts: Vec, diff --git a/nlprule/src/rule/mod.rs b/nlprule/src/rule/mod.rs index 42410c3..799d2f5 100644 --- a/nlprule/src/rule/mod.rs +++ b/nlprule/src/rule/mod.rs @@ -31,7 +31,7 @@ use self::{ /// A *Unification* makes an otherwise matching pattern invalid if no combination of its filters /// matches all tokens marked with "unify". /// Can also be negated. -#[derive(Serialize, Deserialize, Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) struct Unification { pub(crate) mask: Vec>, pub(crate) filters: Vec>, @@ -375,7 +375,7 @@ impl<'a, 't> Iterator for Suggestions<'a, 't> { /// He dosn't know about it. /// /// ``` -#[derive(Serialize, Deserialize, Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Rule { pub(crate) id: Index, pub(crate) engine: Engine, diff --git a/nlprule/src/rules.rs b/nlprule/src/rules.rs index b7b3203..a20e673 100644 --- a/nlprule/src/rules.rs +++ b/nlprule/src/rules.rs @@ -1,20 +1,16 @@ //! Sets of grammatical error correction rules. -use crate::types::*; -use crate::utils::parallelism::MaybeParallelRefIterator; use crate::{rule::id::Selector, tokenizer::Tokenizer}; use crate::{rule::Rule, Error}; +use crate::{spell::Spell, types::*, utils::parallelism::MaybeParallelRefIterator}; use fs_err::File; use serde::{Deserialize, Serialize}; use std::{ - io::{BufReader, Read}, + io::{BufReader, Read, Write}, path::Path, + sync::Arc, }; -/// Options for a rule set. 
-#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct RulesOptions {} - /// Language-dependent options for a rule set. #[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) struct RulesLangOptions { @@ -38,45 +34,68 @@ impl Default for RulesLangOptions { } } -/// A set of grammatical error correction rules. #[derive(Serialize, Deserialize, Default)] +struct RulesFields { + pub(crate) rules: Vec, + pub(crate) spell: Spell, +} + +impl From for RulesFields { + fn from(rules: Rules) -> Self { + RulesFields { + rules: rules.rules, + spell: rules.spell, + } + } +} + +/// A set of grammatical error correction rules. +#[derive(Clone, Default, Serialize, Deserialize)] pub struct Rules { pub(crate) rules: Vec, - pub(crate) options: RulesOptions, + pub(crate) spell: Spell, + pub(crate) tokenizer: Arc, } impl Rules { + /// Serializes the rules set to a writer. + pub fn to_writer(&self, writer: &mut W) -> Result<(), Error> { + // TODO: the .clone() here could be avoided + let fields: RulesFields = self.clone().into(); + writer.write_all(&bincode::serialize(&fields)?)?; + Ok(()) + } + + /// Creates a new rules set from a reader. + pub fn from_reader(reader: R, tokenizer: Arc) -> Result { + let fields: RulesFields = bincode::deserialize_from(reader)?; + let rules = Rules { + rules: fields.rules, + spell: fields.spell, + tokenizer, + }; + Ok(rules) + } + /// Creates a new rule set from a path to a binary. /// /// # Errors /// - If the file can not be opened. /// - If the file content can not be deserialized to a rules set. - pub fn new>(p: P) -> Result { - Rules::new_with_options(p, RulesOptions::default()) - } - - /// Creates a new rule set with options. See [new][Rules::new]. - pub fn new_with_options>(p: P, options: RulesOptions) -> Result { + pub fn new>(p: P, tokenizer: Arc) -> Result { let reader = BufReader::new(File::open(p.as_ref())?); - let mut rules: Rules = bincode::deserialize_from(reader)?; - rules.options = options; - Ok(rules) + Self::from_reader(reader, tokenizer) } - /// Gets the options of this rule set. - pub fn options(&self) -> &RulesOptions { - &self.options + /// Gets the spellchecker associated with this rules set. The spellchecker always exists, even if spellchecking is disabled (default). + pub fn spell(&self) -> &Spell { + &self.spell } - /// Gets the options of this rule set (mutable). - pub fn options_mut(&mut self) -> &mut RulesOptions { - &mut self.options - } - - /// Creates a new rules set from a reader. - pub fn from_reader(reader: R) -> Result { - Ok(bincode::deserialize_from(reader)?) + /// Mutably gets the spellchecker. + pub fn spell_mut(&mut self) -> &mut Spell { + &mut self.spell } /// All rules ordered by priority. @@ -106,7 +125,7 @@ impl Rules { } /// Compute the suggestions for the given tokens by checking all rules. 
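+    /// Suggestions from the spellchecker are included as well if a language variant is selected.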
- pub fn apply(&self, tokens: &[Token], tokenizer: &Tokenizer) -> Vec { + pub fn apply(&self, tokens: &[Token]) -> Vec { if tokens.is_empty() { return Vec::new(); } @@ -119,7 +138,7 @@ impl Rules { .map(|(i, rule)| { let mut output = Vec::new(); - for suggestion in rule.apply(tokens, tokenizer) { + for suggestion in rule.apply(tokens, self.tokenizer.as_ref()) { output.push((i, suggestion)); } @@ -128,6 +147,8 @@ impl Rules { .flatten() .collect(); + output.extend(self.spell.suggest(tokens).into_iter().map(|x| (0, x))); + output.sort_by(|(ia, a), (ib, b)| a.start.cmp(&b.start).then_with(|| ib.cmp(ia))); let mut mask = vec![false; tokens[0].sentence.chars().count()]; @@ -148,7 +169,7 @@ impl Rules { } /// Compute the suggestions for a text by checking all rules. - pub fn suggest(&self, text: &str, tokenizer: &Tokenizer) -> Vec { + pub fn suggest(&self, text: &str) -> Vec { if text.is_empty() { return Vec::new(); } @@ -157,19 +178,15 @@ impl Rules { let mut char_offset = 0; // get suggestions sentence by sentence - for tokens in tokenizer.pipe(text) { + for tokens in self.tokenizer.pipe(text) { if tokens.is_empty() { continue; } - suggestions.extend( - self.apply(&tokens, tokenizer) - .into_iter() - .map(|mut suggestion| { - suggestion.rshift(char_offset); - suggestion - }), - ); + suggestions.extend(self.apply(&tokens).into_iter().map(|mut suggestion| { + suggestion.rshift(char_offset); + suggestion + })); char_offset += tokens[0].sentence.chars().count(); } @@ -178,26 +195,31 @@ impl Rules { } /// Correct a text by first tokenizing, then finding all suggestions and choosing the first replacement of each suggestion. - pub fn correct(&self, text: &str, tokenizer: &Tokenizer) -> String { - let suggestions = self.suggest(text, tokenizer); + pub fn correct(&self, text: &str) -> String { + let suggestions = self.suggest(text); apply_suggestions(text, &suggestions) } } /// Correct a text by applying suggestions to it. -/// In the case of multiple possible replacements, always chooses the first one. +/// - In case of multiple possible replacements, always chooses the first one. +/// - In case of a suggestion without any replacements, ignores the suggestion. 
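+/// - Start/end offsets are treated as char indices into `text`; a running offset accounts for
+///   earlier replacements changing the length of the text.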
pub fn apply_suggestions(text: &str, suggestions: &[Suggestion]) -> String { let mut offset: isize = 0; let mut chars: Vec<_> = text.chars().collect(); for suggestion in suggestions { - let replacement: Vec<_> = suggestion.replacements[0].chars().collect(); - chars.splice( - (suggestion.start as isize + offset) as usize - ..(suggestion.end as isize + offset) as usize, - replacement.iter().cloned(), - ); - offset = offset + replacement.len() as isize - (suggestion.end - suggestion.start) as isize; + if let Some(replacement) = suggestion.replacements.get(0) { + let replacement_chars: Vec<_> = replacement.chars().collect(); + + chars.splice( + (suggestion.start as isize + offset) as usize + ..(suggestion.end as isize + offset) as usize, + replacement_chars.iter().cloned(), + ); + offset = offset + replacement_chars.len() as isize + - (suggestion.end - suggestion.start) as isize; + } } chars.into_iter().collect() diff --git a/nlprule/src/spell/levenshtein.rs b/nlprule/src/spell/levenshtein.rs new file mode 100644 index 0000000..af66e3a --- /dev/null +++ b/nlprule/src/spell/levenshtein.rs @@ -0,0 +1,134 @@ +use fnv::FnvHasher; +use fst::Automaton; +use std::{ + cmp::{self, min}, + hash::{Hash, Hasher}, +}; + +#[derive(Clone, Debug)] +pub struct LevenshteinState { + dist: usize, + n: usize, + // to compute the next row of the matrix, we also need the row two rows up for transposes + prev_row: Option>, + prev_byte: u8, + row: Vec, + hash: u64, +} + +impl LevenshteinState { + pub fn dist(&self) -> usize { + self.dist + } +} + +#[derive(Debug, Clone)] +pub struct Levenshtein<'a> { + query: &'a [u8], + distance: usize, + prefix: usize, +} + +impl<'a> Levenshtein<'a> { + pub fn new(query: &'a str, distance: usize, prefix: usize) -> Self { + Levenshtein { + query: query.as_bytes(), + distance, + prefix, + } + } +} + +impl<'a> Automaton for Levenshtein<'a> { + type State = Option; + + fn start(&self) -> Self::State { + Some(LevenshteinState { + dist: self.query.len(), + n: 0, + prev_row: None, + prev_byte: 0, + row: (0..=self.query.len()).collect(), + hash: FnvHasher::default().finish(), + }) + } + + fn is_match(&self, state: &Self::State) -> bool { + state + .as_ref() + .map_or(false, |state| state.dist <= self.distance) + } + + fn can_match(&self, state: &Self::State) -> bool { + state.is_some() + } + + fn accept(&self, state: &Self::State, byte: u8) -> Self::State { + state.as_ref().and_then(|state| { + let mut next_hasher = FnvHasher::with_key(state.hash); + byte.hash(&mut next_hasher); + let next_hash = next_hasher.finish(); + + let row = &state.row; + let mut next_row = state.row.to_vec(); + + next_row[0] = state.n + 1; + + for i in 1..next_row.len() { + let mut cost = if byte == self.query[i - 1] { + row[i - 1] + } else { + min( + next_row[i - 1] + 1, // deletes + min( + row[i - 1] + 1, // inserts + row[i] + 1, // substitutes + ), + ) + }; + + if i > 1 { + // transposes + if let Some(prev_row) = state.prev_row.as_ref() { + if byte == self.query[i - 2] && state.prev_byte == self.query[i - 1] { + cost = min(cost, prev_row[i - 2] + 1); + } + } + } + + next_row[i] = cost; + } + + let distance = if state.n >= self.prefix { + self.distance + } else { + 1 + }; + + let lower_bound = state.n.saturating_sub(distance); + let upper_bound = cmp::min(state.n + distance, self.query.len()); + + let cutoff = if lower_bound > upper_bound { + 0 + } else { + *next_row[lower_bound..=upper_bound] + .iter() + .min() + .unwrap_or(&0) + }; + + if cutoff > distance { + return None; + } + + Some(LevenshteinState { + dist: 
next_row[self.query.len()], + n: state.n + 1, + prev_row: Some(row.clone()), + prev_byte: byte, + row: next_row, + hash: next_hash, + }) + }) + } +} diff --git a/nlprule/src/spell/mod.rs b/nlprule/src/spell/mod.rs new file mode 100644 index 0000000..86de746 --- /dev/null +++ b/nlprule/src/spell/mod.rs @@ -0,0 +1,487 @@ +//! Structures and implementations related to spellchecking. +use fst::{IntoStreamer, Map, MapBuilder, Streamer}; +use serde::{Deserialize, Serialize}; +use std::{ + cmp::Ordering, + collections::{BinaryHeap, HashSet}, + ops::{Deref, DerefMut}, +}; +use unicode_categories::UnicodeCategories; + +use crate::{ + types::*, + utils::{apply_to_first, is_title_case}, + Error, +}; + +mod levenshtein; +mod spell_int { + use std::cmp; + + use serde::{Deserialize, Serialize}; + + /// Encodes information about a valid word in a `u64` for storage as value in an FST. + /// Currently: + /// - the bottom 8 bits encode the frequency + /// - the other 56 bits act as flags for the variants e.g. bit 10 and 12 are set if the word exists in the the second and fourth variant. + #[derive(Debug, Clone, Default, Copy, Serialize, Deserialize)] + pub(crate) struct SpellInt(pub(super) u64); + + type FreqType = u8; + + const fn freq_size() -> usize { + std::mem::size_of::() * 8 + } + + #[allow(dead_code)] // some methods are only needed for compilation - kept here for clarity + impl SpellInt { + pub fn as_u64(&self) -> u64 { + self.0 + } + + pub fn update_freq(&mut self, freq: usize) { + assert!(freq < FreqType::MAX as usize); + + let prev_freq = self.freq(); + // erase previous frequency + self.0 &= u64::MAX - FreqType::MAX as u64; + // set new frequency, strictly speaking we would have to store a frequency for each variant + // but that would need significantly more space, so we just store the highest frequency + self.0 |= cmp::max(prev_freq, freq) as u64; + } + + pub fn add_variant(&mut self, index: usize) { + assert!(index < 64 - freq_size()); + self.0 |= 1 << (freq_size() + index); + } + + pub fn contains_variant(&self, index: usize) -> bool { + (self.0 >> (freq_size() + index)) & 1 == 1 + } + + pub fn freq(&self) -> usize { + (self.0 & FreqType::MAX as u64) as usize + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn can_encode_freq() { + let mut int = SpellInt::default(); + int.update_freq(100); + int.add_variant(1); + int.add_variant(10); + + assert!(int.freq() == 100); + } + + #[test] + fn can_encode_variants() { + let mut int = SpellInt::default(); + int.update_freq(100); + int.add_variant(1); + int.add_variant(10); + int.update_freq(10); + + assert!(int.contains_variant(1)); + assert!(int.contains_variant(10)); + assert!(!int.contains_variant(2)); + assert!(int.freq() == 100); + } + } +} + +pub(crate) use spell_int::SpellInt; + +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +struct Candidate { + score: f32, + distance: usize, + freq: usize, + term: String, +} +impl Eq for Candidate {} +impl PartialOrd for Candidate { + fn partial_cmp(&self, other: &Self) -> Option { + // higher score => lower order such that sorting puts highest scores first + other.score.partial_cmp(&self.score) + } +} +impl Ord for Candidate { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.partial_cmp(other).expect("scores are never NaN") + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(default)] +/// Options to configure the spellchecker. +pub struct SpellOptions { + /// The language variant to use. 
Setting this to `None` disables spellchecking. + pub variant: Option, + /// The maximum edit distance to consider for corrections. Currently Optimal String Alignment distance is used. + pub max_distance: usize, + /// A fixed prefix length for which to consider only edits with a distance of 1. This speeds up the search by pruning the tree early. + pub prefix_length: usize, + /// How high to weigh the frequency of a word compared to the edit distance when ranking correction candidates. + /// Setting this to `x` makes the frequency make a difference of at most `x` edit distance. + pub freq_weight: f32, + /// The maximum number of correction candidates to return. + pub top_n: usize, + /// A set of words to ignore. Can also contain phrases delimited by a space. + pub whitelist: HashSet, +} + +/// A guard around the [SpellOptions]. Makes sure the spellchecker is updated once this is dropped. +/// Implements `Deref` and `DerefMut` to the [SpellOptions]. +pub struct SpellOptionsGuard<'a> { + spell: &'a mut Spell, +} + +impl<'a> Deref for SpellOptionsGuard<'a> { + type Target = SpellOptions; + + fn deref(&self) -> &Self::Target { + &self.spell.options + } +} + +impl<'a> DerefMut for SpellOptionsGuard<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.spell.options + } +} + +impl<'a> Drop for SpellOptionsGuard<'a> { + fn drop(&mut self) { + self.spell.ingest_options() + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)] +pub(crate) struct SpellLangOptions { + /// Variants of the language (e.g. "en_US", "en_GB") to consider for spellchecking. + pub variants: Vec, + pub split_hyphens: bool, +} + +impl Default for SpellOptions { + fn default() -> Self { + SpellOptions { + variant: None, + max_distance: 2, + prefix_length: 2, + freq_weight: 2., + top_n: 10, + whitelist: HashSet::new(), + } + } +} + +/// A valid language variant. Obtained by [Spell::variant]. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)] +#[serde(transparent)] +pub struct Variant(String); + +impl Variant { + /// Gets the language code of this variant. + pub fn as_str(&self) -> &str { + self.0.as_str() + } +} + +/// Spellchecker logic for one variant. Does the actual work. +#[derive(Debug, Clone)] +pub(crate) struct VariantChecker { + variant: Variant, + fst: Vec, + max_freq: usize, + multiwords: DefaultHashMap>>, + set: DefaultHashSet, + map: DefaultHashMap, + lang_options: SpellLangOptions, + options: SpellOptions, +} + +impl VariantChecker { + /// Checks the validity of one word. + /// NB: The ordering of this chain of `||` operators is somewhat nontrivial. Could potentially be improved by benchmarking. + /// If this is true, the token is always correct. The converse is not true because e.g. multiwords are checked separately. + fn check_word(&self, word: &str, recurse: bool) -> bool { + word.is_empty() + || self.set.contains(word) + || word + .chars() + .all(|x| x.is_symbol() || x.is_punctuation() || x.is_numeric()) + || (recurse + // for title case words, it is enough if the lowercase variant is known. + // it is possible that `is_title_case` is still true for word where `.to_lowercase()` was called so we need a `recurse` parameter. + && is_title_case(word) + && self.check_word(&apply_to_first(word, |x| x.to_lowercase().collect()), false)) + } + + /// Populates `correct_mask` according to the correctness of the given zeroth token. + /// - `correct_mask[0]` is `true` if the zeroth token is correct, `false` if it is not correct. 
+ /// - Indices `1..n` of `correct_mask` are `true` if the `n`th token is also definitely correct. + /// If they are `false`, they need to be checked separately. + fn check(&self, tokens: &[Token], correct_mask: &mut [bool]) { + let word = tokens[0].word.text.as_ref(); + let mut word_is_correct = self.check_word(word, true); + + if !word_is_correct && self.lang_options.split_hyphens { + // there exist multiple valid hyphens, see https://jkorpela.fi/dashes.html + let hyphens = &['-', '\u{2010}', '\u{2011}'][..]; + + if word.contains(hyphens) && word.split(hyphens).all(|x| self.check_word(x, true)) { + word_is_correct = true; + } + } + + correct_mask[0] = word_is_correct; + + if let Some(continuations) = self.multiwords.get(word) { + if let Some(matching_cont) = continuations.iter().find(|cont| { + // important: an empty continuation matches! so single words can also validly be part of `multiwords` + (tokens.len() - 1) >= cont.len() + && cont + .iter() + .enumerate() + .all(|(i, x)| tokens[i + 1].word.text.as_ref() == x) + }) { + correct_mask[..1 + matching_cont.len()] + .iter_mut() + .for_each(|x| *x = true); + } + } + } + + fn search(&self, word: &str) -> Vec { + if let Some(candidate) = self.map.get(word) { + return vec![candidate.to_owned()]; + } + + let used_fst = Map::new(self.fst.as_slice()).expect("used fst must be valid."); + let query = levenshtein::Levenshtein::new(word, self.options.max_distance, 2); + + let mut out = BinaryHeap::with_capacity(self.options.top_n); + + let mut stream = used_fst.search_with_state(query).into_stream(); + while let Some((k, v, s)) = stream.next() { + let state = s.expect("matching levenshtein state is always `Some`."); + + let id = SpellInt(v); + + let term = String::from_utf8(k.to_vec()).expect("fst keys must be valid utf-8."); + out.push(Candidate { + distance: state.dist(), + freq: id.freq(), + term, + score: (self.options.max_distance - state.dist()) as f32 + + id.freq() as f32 / self.max_freq as f32 * self.options.freq_weight, + }); + if out.len() > self.options.top_n { + out.pop(); + } + } + + // `into_iter_sorted` is unstable - see https://github.com/rust-lang/rust/issues/59278 + out.into_sorted_vec().into_iter().map(|x| x.term).collect() + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +/// A spellchecker implementing the algorithm described in [Error-tolerant Finite State Recognition](https://www.aclweb.org/anthology/1995.iwpt-1.24/) with some extensions. +pub struct Spell { + /// An FST mapping valid words (always single tokens!) to a [SpellInt]. + pub(crate) fst: Vec, + /// Known *multiwords* i. e. phrases. Can also validly contain single words if they should not be part of the FST (e.g. words in the whitelist). + pub(crate) multiwords: DefaultHashMap, SpellInt)>>, + /// The maximum occured word frequency. Used to normalize. + pub(crate) max_freq: usize, + /// A map of `wrong->right`. `wrong` must always be exactly one token. + pub(crate) map: DefaultHashMap, + pub(crate) lang_options: SpellLangOptions, + pub(crate) options: SpellOptions, + /// The structure containing the actual spellchecking logic. Computed based on the selected variant. + #[serde(skip)] + pub(crate) variant_checker: Option, +} + +impl Spell { + /// Gets the options. + pub fn options(&self) -> &SpellOptions { + &self.options + } + + /// Mutably gets the options. + pub fn options_mut(&mut self) -> SpellOptionsGuard { + SpellOptionsGuard { spell: self } + } + + /// Returns all known variants. 
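+    /// (e.g. `en_GB`, `en_US` and `en_AU` for the English rules, as listed in `configs/en/spellchecker.json`).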
+ pub fn variants(&self) -> &[Variant] { + self.lang_options.variants.as_slice() + } + + /// Returns the variant for a language code e.g. `"en_GB"`. + /// # Errors + /// - If no variant exists for the language code. + pub fn variant(&self, variant: &str) -> Result { + self.lang_options + .variants + .iter() + .find(|x| x.as_str() == variant) + .cloned() + .ok_or_else(|| { + Error::UnknownVariant( + variant.to_owned(), + self.lang_options + .variants + .iter() + .map(|x| x.as_str().to_owned()) + .collect(), + ) + }) + } + + pub(crate) fn ingest_options(&mut self) { + let variant = if let Some(variant) = self.options.variant.as_ref() { + variant.clone() + } else { + self.variant_checker = None; + return; + }; + + let variant_index = self + .variants() + .iter() + .position(|x| *x == variant) + .expect("only valid variants are created."); + + let mut checker = match self.variant_checker.take() { + // if the variant checker exists and uses the correct variant, we don't need to rebuild + Some(checker) if checker.variant == variant => checker, + _ => { + let mut used_fst_builder = MapBuilder::memory(); + let mut set = DefaultHashSet::new(); + + let fst = Map::new(&self.fst).expect("serialized fst must be valid."); + let mut stream = fst.into_stream(); + + while let Some((k, v)) = stream.next() { + if SpellInt(v).contains_variant(variant_index) { + set.insert( + String::from_utf8(k.to_vec()).expect("fst keys must be valid utf-8."), + ); + used_fst_builder + .insert(k, v) + .expect("fst stream returns values in lexicographic order."); + } + } + + let fst = used_fst_builder + .into_inner() + .expect("subset of valid fst must be valid."); + + VariantChecker { + variant, + fst, + multiwords: DefaultHashMap::new(), + set, + map: self.map.clone(), + max_freq: self.max_freq, + options: self.options.clone(), + lang_options: self.lang_options.clone(), + } + } + }; + + // `multiwords` depend on the whitelist. For convenience we always rebuild this. + // the whitelist could be separated into a new structure for a speedup. + // We can revisit this if performance becomes an issue, it should still be quite fast as implemented now. + + // selects only the multiwords which exist for the selected variant + let mut multiwords: DefaultHashMap<_, _> = self + .multiwords + .iter() + .map(|(key, value)| { + let value = value + .iter() + .filter_map(|(continuations, int)| { + if int.contains_variant(variant_index) { + Some(continuations) + } else { + None + } + }) + .cloned() + .collect(); + (key.to_owned(), value) + }) + .collect(); + + // adds words from the user-set whitelist + // careful: words in the `whitelist` are set by the user, so this must never fail! + for phrase in self + .options + .whitelist + .iter() + .map(|x| x.as_str()) + // for some important words we have to manually make sure they are ignored :) + .chain(vec!["nlprule", "Minixhofer"]) + { + let mut parts = phrase.trim().split_whitespace(); + + let first = if let Some(first) = parts.next() { + first + } else { + // silently ignore empty words + continue; + }; + + multiwords + .entry(first.to_owned()) + .or_insert_with(Vec::new) + .push(parts.map(|x| x.to_owned()).collect()); + } + + checker.multiwords = multiwords; + self.variant_checker = Some(checker); + } + + /// Runs the spellchecking algorithm on all tokens and returns suggestions. 
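+    /// Returns an empty vector if no variant is selected, i.e. if spellchecking is disabled.
+    /// Tokens marked with `ignore_spelling` are skipped.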
+ pub fn suggest(&self, tokens: &[Token]) -> Vec { + let variant_checker = if let Some(checker) = self.variant_checker.as_ref() { + checker + } else { + return Vec::new(); + }; + + let mut suggestions = Vec::new(); + let mut correct_mask = vec![false; tokens.len()]; + + for (i, token) in tokens.iter().enumerate() { + let text = token.word.text.as_ref(); + + if !correct_mask[i] { + variant_checker.check(&tokens[i..], &mut correct_mask[i..]); + } + if correct_mask[i] || token.ignore_spelling { + continue; + } + + suggestions.push(Suggestion { + source: "SPELLCHECK/SINGLE".into(), + message: "Possibly misspelled word.".into(), + start: token.char_span.0, + end: token.char_span.1, + replacements: variant_checker.search(text), + }); + } + + suggestions + } +} diff --git a/nlprule/src/tokenizer.rs b/nlprule/src/tokenizer.rs index 1887636..de5834a 100644 --- a/nlprule/src/tokenizer.rs +++ b/nlprule/src/tokenizer.rs @@ -13,7 +13,7 @@ use crate::{ use fs_err::File; use serde::{Deserialize, Serialize}; use std::{ - io::{BufReader, Read}, + io::{BufReader, Read, Write}, path::Path, sync::Arc, }; @@ -128,6 +128,11 @@ impl Tokenizer { Ok(bincode::deserialize_from(reader)?) } + /// Serializes the tokenizer to a writer. + pub fn to_writer(&self, writer: &mut W) -> Result<(), Error> { + Ok(bincode::serialize_into(writer, &self)?) + } + /// Gets all disambigation rules in the order they are applied. pub fn rules(&self) -> &[DisambiguationRule] { &self.rules @@ -196,7 +201,7 @@ impl Tokenizer { self.disambiguate_up_to_id(tokens, None) } - fn get_token_strs<'t>(&self, text: &'t str) -> Vec<&'t str> { + pub(crate) fn get_token_strs<'t>(&self, text: &'t str) -> Vec<&'t str> { let mut tokens = Vec::new(); let split_char = |c: char| c.is_whitespace() || crate::utils::splitting_chars().contains(c); @@ -274,6 +279,7 @@ impl Tokenizer { char_span: (char_start, current_char), byte_span: (byte_start, byte_start + x.len()), is_sentence_end, + ignore_spelling: false, has_space_before: sentence[..byte_start].ends_with(char::is_whitespace), chunks: Vec::new(), multiword_data: None, diff --git a/nlprule/src/tokenizer/multiword.rs b/nlprule/src/tokenizer/multiword.rs index e3e9bfc..2694524 100644 --- a/nlprule/src/tokenizer/multiword.rs +++ b/nlprule/src/tokenizer/multiword.rs @@ -70,6 +70,7 @@ impl MultiwordTagger { tagger.id_word(word.as_str().into()), pos.as_ref_id(), )); + token.ignore_spelling = true; } } } diff --git a/nlprule/src/tokenizer/tag.rs b/nlprule/src/tokenizer/tag.rs index 568aa6e..dd3879c 100644 --- a/nlprule/src/tokenizer/tag.rs +++ b/nlprule/src/tokenizer/tag.rs @@ -9,7 +9,7 @@ use log::error; use serde::{Deserialize, Serialize}; use std::{borrow::Cow, iter::once}; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Default, Serialize, Deserialize)] pub(crate) struct TaggerLangOptions { /// Whether to use a heuristic to split potential compound words. 
 pub use_compound_split_heuristic: bool,
@@ -19,16 +19,6 @@ pub(crate) struct TaggerLangOptions {
 pub extra_tags: Vec,
 }
-impl Default for TaggerLangOptions {
- fn default() -> Self {
- TaggerLangOptions {
- use_compound_split_heuristic: false,
- always_add_lower_tags: false,
- extra_tags: Vec::new(),
- }
- }
-}
-
 #[derive(Serialize, Deserialize)]
 struct TaggerFields {
 tag_fst: Vec<u8>,
@@ -214,13 +204,13 @@ impl Tagger {
 &self.word_store
 }
- fn str_for_word_id(&self, id: &WordIdInt) -> &str {
+ pub(crate) fn str_for_word_id(&self, id: &WordIdInt) -> &str {
 self.word_store
 .get_by_right(id)
 .expect("only valid word ids are created")
 }
- fn str_for_pos_id(&self, id: &PosIdInt) -> &str {
+ pub(crate) fn str_for_pos_id(&self, id: &PosIdInt) -> &str {
 self.tag_store
 .get_by_right(id)
 .expect("only valid pos ids are created")
diff --git a/nlprule/src/types.rs b/nlprule/src/types.rs
index a49572b..68104bb 100644
--- a/nlprule/src/types.rs
+++ b/nlprule/src/types.rs
@@ -13,9 +13,11 @@ pub(crate) type DefaultHashMap<K, V> = HashMap<K, V>;
 pub(crate) type DefaultHashSet<T> = HashSet<T>;
 pub(crate) type DefaultHasher = hash_map::DefaultHasher;
-#[derive(Debug, Copy, Clone, Serialize, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)]
-#[serde(transparent)]
+#[derive(
+ Debug, Copy, Clone, Serialize, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd, Default,
+)]
 pub(crate) struct WordIdInt(pub u32);
+
 #[derive(Debug, Copy, Clone, Serialize, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)]
 #[serde(transparent)]
 pub(crate) struct PosIdInt(pub u16);
@@ -196,7 +198,7 @@ pub struct IncompleteToken<'t> {
 pub byte_span: (usize, usize),
 /// Char start (inclusive) and end (exclusive) of this token in the sentence.
 pub char_span: (usize, usize),
- /// Whether this token is the last token in the sentence-
+ /// Whether this token is the last token in the sentence.
 pub is_sentence_end: bool,
 /// Whether this token has one or more whitespace characters before.
 pub has_space_before: bool,
@@ -204,6 +206,8 @@ pub struct IncompleteToken<'t> {
 pub chunks: Vec<String>,
 /// A *multiword* lemma and part-of-speech tag. Set if the token was found in a list of phrases.
 pub multiword_data: Option<WordData<'t>>,
+ /// Whether to ignore spelling for this token.
+ pub ignore_spelling: bool,
 /// The sentence this token is in.
 pub sentence: &'t str,
 /// The tagger used for lookup related to this token.
@@ -225,6 +229,7 @@ pub struct Token<'t> {
 pub word: Word<'t>,
 pub char_span: (usize, usize),
 pub byte_span: (usize, usize),
+ pub ignore_spelling: bool,
 pub has_space_before: bool,
 pub chunks: Vec<String>,
 pub sentence: &'t str,
@@ -247,6 +252,7 @@ impl<'t> Token<'t> {
 ),
 char_span: (0, 0),
 byte_span: (0, 0),
+ ignore_spelling: true,
 has_space_before: false,
 chunks: Vec::new(),
 sentence,
@@ -296,6 +302,7 @@ impl<'t> From<IncompleteToken<'t>> for Token<'t> {
 word,
 byte_span: data.byte_span,
 char_span: data.char_span,
+ ignore_spelling: data.ignore_spelling,
 has_space_before: data.has_space_before,
 chunks: data.chunks,
 sentence: data.sentence,
diff --git a/nlprule/tests/tests.rs b/nlprule/tests/tests.rs
index 270a17d..7cca1d5 100644
--- a/nlprule/tests/tests.rs
+++ b/nlprule/tests/tests.rs
@@ -1,15 +1,15 @@
-use std::convert::TryInto;
+use std::{convert::TryInto, sync::Arc};
 use lazy_static::lazy_static;
-use nlprule::{rule::id::Category, Rules, Tokenizer};
+use nlprule::{rule::id::Category, Error, Rules, Tokenizer};
 use quickcheck_macros::quickcheck;
 const TOKENIZER_PATH: &str = "../storage/en_tokenizer.bin";
 const RULES_PATH: &str = "../storage/en_rules.bin";
 lazy_static! {
- static ref TOKENIZER: Tokenizer = Tokenizer::new(TOKENIZER_PATH).unwrap();
- static ref RULES: Rules = Rules::new(RULES_PATH).unwrap();
+ static ref TOKENIZER: Arc<Tokenizer> = Arc::new(Tokenizer::new(TOKENIZER_PATH).unwrap());
+ static ref RULES: Rules = Rules::new(RULES_PATH, TOKENIZER.clone()).unwrap();
 }
 #[test]
@@ -25,12 +25,10 @@ fn can_tokenize_anything(text: String) -> bool {
 #[test]
 fn rules_can_be_disabled_enabled() {
- let mut rules = Rules::new(RULES_PATH).unwrap();
+ let mut rules = Rules::new(RULES_PATH, TOKENIZER.clone()).unwrap();
 // enabled by default
- assert!(!rules
- .suggest("I can due his homework", &*TOKENIZER)
- .is_empty());
+ assert!(!rules.suggest("I can due his homework").is_empty());
 rules
 .select_mut(
@@ -41,17 +39,25 @@ fn rules_can_be_disabled_enabled() {
 .for_each(|x| x.disable());
 // disabled now
- assert!(rules
- .suggest("I can due his homework", &*TOKENIZER)
- .is_empty());
+ assert!(rules.suggest("I can due his homework").is_empty());
 // disabled by default
- assert!(rules.suggest("I can not go", &*TOKENIZER).is_empty());
+ assert!(rules.suggest("I can not go").is_empty());
 rules
 .select_mut(&"typos/can_not".try_into().unwrap())
 .for_each(|x| x.enable());
 // enabled now
- assert!(!rules.suggest("I can not go", &*TOKENIZER).is_empty());
+ assert!(!rules.suggest("I can not go").is_empty());
+}
+
+#[test]
+fn spellchecker_works() -> Result<(), Error> {
+ let mut rules = Rules::new(RULES_PATH, TOKENIZER.clone()).unwrap();
+ rules.spell_mut().options_mut().variant = Some(rules.spell().variant("en_GB")?);
+
+ assert_eq!(rules.correct("color spellhceking"), "colour spellchecking");
+
+ Ok(())
+}
diff --git a/python/Cargo.toml b/python/Cargo.toml
index 6979312..539f3d9 100644
--- a/python/Cargo.toml
+++ b/python/Cargo.toml
@@ -15,6 +15,7 @@ parking_lot = { version = "0.11", features = ["serde"] }
 reqwest = { version = "0.11", default_features = false, features = ["blocking", "rustls-tls"]}
 flate2 = "1"
 directories = "3"
+pythonize = "0.13"
 syn = "=1.0.57" # workaround for "could not find `export` in `syn`" by enum_dispatch
 nlprule = { path = "../nlprule" } # BUILD_BINDINGS_COMMENT
# nlprule = { package = "nlprule-core", path = "../nlprule" } # BUILD_BINDINGS_UNCOMMENT
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 54084cb..fd1b905 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -2,6 +2,7 @@ use flate2::read::GzDecoder;
 use nlprule::{
 rule::{id::Selector, Example, Rule},
 rules::{apply_suggestions, Rules},
+ spell::Spell,
 tokenizer::tag::Tagger,
 tokenizer::Tokenizer,
 types::*,
@@ -9,10 +10,16 @@ use nlprule::{
 use parking_lot::{
 MappedRwLockReadGuard, MappedRwLockWriteGuard, RwLock, RwLockReadGuard, RwLockWriteGuard,
 };
-use pyo3::prelude::*;
-use pyo3::types::PyString;
 use pyo3::{exceptions::PyValueError, types::PyBytes};
+use pyo3::{
+ prelude::*,
+ types::{PyDict, PyFrozenSet},
+ wrap_pymodule,
+};
+use pyo3::{types::PyString, ToPyObject};
+use pythonize::depythonize;
 use std::{
+ collections::HashSet,
 convert::TryFrom,
 fs,
 io::{Cursor, Read},
@@ -20,6 +27,10 @@ use std::{
 sync::Arc,
 };
+fn err(error: nlprule::Error) -> PyErr {
+ PyValueError::new_err(format!("{}", error))
+}
+
 fn get_resource(lang_code: &str, name: &str) -> PyResult<impl Read> {
 let version = env!("CARGO_PKG_VERSION");
 let mut cache_path: Option<PathBuf> = None;
@@ -306,40 +317,42 @@ impl From<Suggestion> for PySuggestion {
 /// When created from a language code, the binary is downloaded from the internet the first time.
 /// Then it is stored in your cache and loaded from there.
#[pyclass(name = "Tokenizer", module = "nlprule")] -#[text_signature = "(path, sentence_splitter=None)"] +#[text_signature = "(path)"] #[derive(Default)] pub struct PyTokenizer { - tokenizer: Tokenizer, + tokenizer: Arc, } impl PyTokenizer { - fn tokenizer(&self) -> &Tokenizer { + fn tokenizer(&self) -> &Arc { &self.tokenizer } } #[pymethods] impl PyTokenizer { - #[text_signature = "(code, sentence_splitter=None)"] + #[text_signature = "(code)"] #[staticmethod] fn load(lang_code: &str) -> PyResult { let bytes = get_resource(lang_code, "tokenizer.bin.gz")?; - let tokenizer: Tokenizer = bincode::deserialize_from(bytes) - .map_err(|x| PyValueError::new_err(format!("{}", x)))?; - Ok(PyTokenizer { tokenizer }) + let tokenizer = Tokenizer::from_reader(bytes).map_err(err)?; + Ok(PyTokenizer { + tokenizer: Arc::new(tokenizer), + }) } #[new] fn new(path: Option<&str>) -> PyResult { let tokenizer = if let Some(path) = path { - Tokenizer::new(path) - .map_err(|x| PyValueError::new_err(format!("error creating Tokenizer: {}", x)))? + Tokenizer::new(path).map_err(err)? } else { Tokenizer::default() }; - Ok(PyTokenizer { tokenizer }) + Ok(PyTokenizer { + tokenizer: Arc::new(tokenizer), + }) } /// Get the tagger dictionary of this tokenizer. @@ -402,8 +415,8 @@ impl PyTokenizer { } } -impl From for PyTokenizer { - fn from(tokenizer: Tokenizer) -> Self { +impl From> for PyTokenizer { + fn from(tokenizer: Arc) -> Self { PyTokenizer { tokenizer } } } @@ -474,7 +487,7 @@ impl PyRule { RwLockWriteGuard::map(self.rules.write(), |x| &mut x.rules_mut()[self.index]) } - fn from_rule(index: usize, rules: Arc>) -> Self { + fn from_rules(index: usize, rules: Arc>) -> Self { PyRule { rules, index } } } @@ -536,6 +549,144 @@ impl PyRule { } } +#[pyclass(name = "SpellOptions", module = "nlprule.spell")] +struct PySpellOptions { + rules: Arc>, +} + +impl PySpellOptions { + fn spell(&self) -> MappedRwLockReadGuard<'_, Spell> { + RwLockReadGuard::map(self.rules.read(), |x| x.spell()) + } + + fn spell_mut(&self) -> MappedRwLockWriteGuard<'_, Spell> { + RwLockWriteGuard::map(self.rules.write(), |x| x.spell_mut()) + } +} + +#[pymethods] +impl PySpellOptions { + #[getter] + fn get_variant(&self) -> Option { + self.spell() + .options() + .variant + .as_ref() + .map(|x| x.as_str().to_owned()) + } + + #[setter] + fn set_variant(&self, variant: Option<&str>) -> PyResult<()> { + if let Some(variant) = variant { + let mut spell = self.spell_mut(); + let variant = spell.variant(variant).map_err(err)?; + + spell.options_mut().variant = Some(variant); + } else { + self.spell_mut().options_mut().variant = None; + } + + Ok(()) + } + + #[getter] + fn get_max_distance(&self) -> usize { + self.spell().options().max_distance + } + + #[setter] + fn set_max_distance(&self, max_distance: usize) { + self.spell_mut().options_mut().max_distance = max_distance + } + + #[getter] + fn get_prefix_length(&self) -> usize { + self.spell().options().prefix_length + } + + #[setter] + fn set_prefix_length(&self, prefix_length: usize) { + self.spell_mut().options_mut().prefix_length = prefix_length + } + + #[getter] + fn get_freq_weight(&self) -> f32 { + self.spell().options().freq_weight + } + + #[setter] + fn set_freq_weight(&self, freq_weight: f32) { + self.spell_mut().options_mut().freq_weight = freq_weight + } + + #[getter] + fn get_top_n(&self) -> usize { + self.spell().options().top_n + } + + #[setter] + fn set_top_n(&self, top_n: usize) { + self.spell_mut().options_mut().top_n = top_n + } + + #[getter] + fn get_whitelist<'py>(&self, py: Python<'py>) 
+ let spell = self.spell();
+ let whitelist: Vec<&str> = spell
+ .options()
+ .whitelist
+ .iter()
+ .map(|x| x.as_str())
+ .collect();
+
+ PyFrozenSet::new(py, &whitelist)
+ }
+
+ #[setter]
+ fn set_whitelist(&self, py: Python, whitelist: PyObject) -> PyResult<()> {
+ let whitelist: PyResult<Vec<String>> = whitelist
+ .as_ref(py)
+ .iter()?
+ .map(|x| x.and_then(PyAny::extract::<String>))
+ .collect();
+ self.spell_mut().options_mut().whitelist = whitelist?;
+ Ok(())
+ }
+}
+
+#[pyclass(name = "Spell", module = "nlprule.spell")]
+struct PySpell {
+ rules: Arc<RwLock<Rules>>,
+}
+
+#[pymethods]
+impl PySpell {
+ #[getter]
+ fn variants(&self) -> Vec<String> {
+ self.rules
+ .read()
+ .spell()
+ .variants()
+ .iter()
+ .map(|x| x.as_str().to_owned())
+ .collect()
+ }
+
+ #[getter]
+ fn get_options(&self) -> PySpellOptions {
+ PySpellOptions {
+ rules: self.rules.clone(),
+ }
+ }
+
+ #[setter]
+ fn set_options(&self, py: Python, options: &PyDict) -> PyResult<()> {
+ let mut guard = self.rules.write();
+ *guard.spell_mut().options_mut() = depythonize(options.to_object(py).as_ref(py))?;
+ Ok(())
+ }
+}
+
 /// The grammatical rules.
 /// Can be created from a rules binary:
 /// ```python
@@ -549,47 +700,50 @@ impl PyRule {
 /// When created from a language code, the binary is downloaded from the internet the first time.
 /// Then it is stored in your cache and loaded from there.
 #[pyclass(name = "Rules", module = "nlprule")]
-#[text_signature = "(path, tokenizer, sentence_splitter=None)"]
+#[text_signature = "(path, tokenizer)"]
 struct PyRules {
 rules: Arc<RwLock<Rules>>,
- tokenizer: Py<PyTokenizer>,
 }
 #[pymethods]
 impl PyRules {
- #[text_signature = "(code, tokenizer, sentence_splitter=None)"]
+ #[text_signature = "(code, tokenizer)"]
 #[staticmethod]
- fn load(lang_code: &str, tokenizer: Py<PyTokenizer>) -> PyResult<Self> {
+ fn load(lang_code: &str, tokenizer: &PyTokenizer) -> PyResult<Self> {
 let bytes = get_resource(lang_code, "rules.bin.gz")?;
- let rules: Rules = bincode::deserialize_from(bytes)
- .map_err(|x| PyValueError::new_err(format!("{}", x)))?;
+ let rules = Rules::from_reader(bytes, tokenizer.tokenizer().clone()).map_err(err)?;
 Ok(PyRules {
 rules: Arc::from(RwLock::from(rules)),
- tokenizer,
 })
 }
 #[new]
- fn new(py: Python, path: Option<&str>, tokenizer: Option<Py<PyTokenizer>>) -> PyResult<Self> {
- let rules = if let Some(path) = path {
- Rules::new(path)
- .map_err(|x| PyValueError::new_err(format!("error creating Rules: {}", x)))?
+ fn new(path: Option<&str>, tokenizer: Option<&PyTokenizer>) -> PyResult<Self> {
+ let tokenizer = if let Some(tokenizer) = tokenizer {
+ tokenizer.tokenizer().clone()
 } else {
- Rules::default()
+ PyTokenizer::default().tokenizer().clone()
 };
- let tokenizer = if let Some(tokenizer) = tokenizer {
- tokenizer
+
+ let rules = if let Some(path) = path {
+ Rules::new(path, tokenizer).map_err(err)?
 } else {
- Py::new(py, PyTokenizer::default())?
+ Rules::default()
 };
 Ok(PyRules {
 rules: Arc::from(RwLock::from(rules)),
- tokenizer,
 })
 }
+ #[getter]
+ fn spell(&self) -> PySpell {
+ PySpell {
+ rules: self.rules.clone(),
+ }
+ }
+
 #[getter]
 fn rules(&self) -> Vec<PyRule> {
 self.rules
 .read()
 .rules()
 .iter()
 .enumerate()
- .map(|(i, _)| PyRule::from_rule(i, self.rules.clone()))
+ .map(|(i, _)| PyRule::from_rules(i, self.rules.clone()))
 .collect()
 }
 /// Finds a rule by selector.
 fn select(&self, id: &str) -> PyResult<Vec<PyRule>> {
 let selector = Selector::try_from(id.to_owned())
- .map_err(|err| PyValueError::new_err(format!("error creating selector: {}", err)))?;
+ .map_err(|err| PyValueError::new_err(format!("{}", err)))?;
 Ok(self
 .rules
 .read()
 .rules()
 .iter()
 .enumerate()
 .filter(|(_, rule)| selector.is_match(rule.id()))
- .map(|(i, _)| PyRule::from_rule(i, self.rules.clone()))
+ .map(|(i, _)| PyRule::from_rules(i, self.rules.clone()))
 .collect())
 }
@@ -625,15 +779,12 @@ impl PyRules {
 /// Returns:
 /// suggestions (Union[List[Suggestion], List[List[Suggestion]]]):
 /// The computed suggestions. Batched if the input is batched.
- #[text_signature = "(sentence_or_sentences)"]
- fn suggest(&self, py: Python, sentence_or_sentences: PyObject) -> PyResult<PyObject> {
- text_guard(py, sentence_or_sentences, |sentence| {
- let tokenizer = self.tokenizer.borrow(py);
- let tokenizer = tokenizer.tokenizer();
-
+ #[text_signature = "(text_or_texts)"]
+ fn suggest(&self, py: Python, text_or_texts: PyObject) -> PyResult<PyObject> {
+ text_guard(py, text_or_texts, |text| {
 self.rules
 .read()
- .suggest(&sentence, &tokenizer)
+ .suggest(&text)
 .into_iter()
 .map(|x| PyCell::new(py, PySuggestion::from(x)))
 .collect::<PyResult<Vec<_>>>()
@@ -651,15 +802,13 @@ impl PyRules {
 #[text_signature = "(text_or_texts)"]
 fn correct(&self, py: Python, text_or_texts: PyObject) -> PyResult<PyObject> {
 text_guard(py, text_or_texts, |text| {
- let tokenizer = self.tokenizer.borrow(py);
- let tokenizer = tokenizer.tokenizer();
-
- Ok(self.rules.read().correct(&text, tokenizer))
+ Ok(self.rules.read().correct(&text))
 })
 }
- /// Convenience method to apply suggestions to the given text.
- /// Always uses the first element of `suggestion.replacements` as replacement.
+ /// Correct a text by applying suggestions to it.
+ /// - In case of multiple possible replacements, always chooses the first one.
+ /// - In case of a suggestion without any replacements, ignores the suggestion.
 ///
 /// Arguments:
 /// text (str): The input text.
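Taken together, the `PySpell`/`PySpellOptions` classes and the reworked `PyRules` above expose the spellchecker as `rules.spell` on the Python side. A minimal usage sketch follows; the binary paths and the `en_GB` variant are illustrative assumptions, and the suggestion attribute names mirror the Rust `Suggestion` struct rather than anything introduced in this diff:

```python
# Sketch only: assumes en_tokenizer.bin / en_rules.bin were built into storage/
# as described in build/README.md.
from nlprule import Tokenizer, Rules

tokenizer = Tokenizer("storage/en_tokenizer.bin")
rules = Rules("storage/en_rules.bin", tokenizer)

# Spellchecking stays inactive until a variant is selected;
# an unknown variant raises a ValueError (see test.py below).
print(rules.spell.variants)
rules.spell.options.variant = "en_GB"
rules.spell.options.whitelist = {"nlprule"}

print(rules.correct("color spellhceking"))
for suggestion in rules.suggest("color spellhceking"):
    print(suggestion.start, suggestion.end, suggestion.replacements)
```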
@@ -691,13 +840,11 @@ impl PyRules {
 pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
 match state.extract::<&PyBytes>(py) {
 Ok(s) => {
- let state: (Rules, Tokenizer) =
- bincode::deserialize(s.as_bytes()).map_err(|_| {
- PyValueError::new_err("deserializing state with `bincode` failed")
- })?;
+ let rules: Rules = bincode::deserialize(s.as_bytes()).map_err(|_| {
+ PyValueError::new_err("deserializing state with `bincode` failed")
+ })?;
 // a roundtrip through pickle cannot preserve references so we need to create a new Arc<RwLock<Rules>>
- self.rules = Arc::from(RwLock::from(state.0));
- self.tokenizer = Py::new(py, PyTokenizer::from(state.1))?;
+ self.rules = Arc::new(RwLock::new(rules));
 Ok(())
 }
 Err(e) => Err(e),
@@ -705,26 +852,33 @@ impl PyRules {
 }
 pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
- let tokenizer = self.tokenizer.borrow(py);
-
- // rwlock is serialized the same way as the inner type
- let state = (&self.rules, tokenizer.tokenizer());
-
 Ok(PyBytes::new(
 py,
- &bincode::serialize(&state)
+ // rwlock serialization is transparent
+ &bincode::serialize(&self.rules)
 .map_err(|_| PyValueError::new_err("serializing state with `bincode` failed"))?,
 )
 .to_object(py))
 }
 }
+#[pymodule]
+fn spell(_py: Python, m: &PyModule) -> PyResult<()> {
+ m.add_class::<PySpell>()?;
+ m.add_class::<PySpellOptions>()?;
+ Ok(())
+}
+
 #[pymodule]
 fn nlprule(_py: Python, m: &PyModule) -> PyResult<()> {
 m.add("__version__", env!("CARGO_PKG_VERSION"))?;
+
 m.add_class::<PyTokenizer>()?;
 m.add_class::<PyRules>()?;
 m.add_class::<PySuggestion>()?;
 m.add_class::<PyRule>()?;
+ m.add_wrapped(wrap_pymodule!(spell))?;
+
 Ok(())
 }
diff --git a/python/test.py b/python/test.py
index 086b1a4..799216d 100644
--- a/python/test.py
+++ b/python/test.py
@@ -138,3 +138,24 @@ def test_rules_can_be_disabled(tokenizer_and_rules):
 rule.disable()
 assert len(rules.suggest("I can due his homework")) == 0
+
+def test_spell_options_can_be_read(tokenizer_and_rules):
+ (tokenizer, rules) = tokenizer_and_rules
+
+ assert rules.spell.options.max_distance > 0
+ assert rules.spell.options.variant is None
+
+def test_spell_options_can_be_set(tokenizer_and_rules):
+ (tokenizer, rules) = tokenizer_and_rules
+
+ with pytest.raises(ValueError):
+ rules.spell.options.variant = "en_INVALID"
+
+ rules.spell.options.variant = "en_GB"
+ assert rules.spell.options.variant == "en_GB"
+
+def test_spellchecker_works(tokenizer_and_rules):
+ (tokenizer, rules) = tokenizer_and_rules
+
+ # TODO
+ # print(rules.spell.search("lämp"))
\ No newline at end of file
diff --git a/scripts/build_and_test.sh b/scripts/build_and_test.sh
index 6f43f9f..d060156 100755
--- a/scripts/build_and_test.sh
+++ b/scripts/build_and_test.sh
@@ -1,6 +1,23 @@
 # this script assumes the build directories are in data/
 # only for convenience
 mkdir -p storage
-RUST_LOG=INFO cargo run --all-features --bin compile -- --build-dir data/$1 --tokenizer-out storage/$1_tokenizer.bin --rules-out storage/$1_rules.bin
-RUST_LOG=WARN cargo run --all-features --bin test_disambiguation -- --tokenizer storage/$1_tokenizer.bin
-RUST_LOG=WARN cargo run --all-features --bin test -- --tokenizer storage/$1_tokenizer.bin --rules storage/$1_rules.bin
\ No newline at end of file
+
+# x-- => only compile
+# -xx => test_disambiguation and test
+# xxx or flags not set => everything
+flags=${2:-"xxx"}
+
+if [ "${flags:0:1}" == "x" ]
+then
+ RUST_LOG=INFO cargo run --all-features --bin compile -- --build-dir data/$1 --tokenizer-out storage/$1_tokenizer.bin --rules-out storage/$1_rules.bin
+fi
+
+if [ "${flags:1:1}" == "x" ]
+then
+ RUST_LOG=WARN cargo run --all-features --bin test_disambiguation -- --tokenizer storage/$1_tokenizer.bin
+fi
+
+if [ "${flags:2:1}" == "x" ]
+then
+ RUST_LOG=WARN cargo run --all-features --bin test -- --tokenizer storage/$1_tokenizer.bin --rules storage/$1_rules.bin
+fi
\ No newline at end of file
diff --git a/scripts/maturin.sh b/scripts/maturin.sh
index 7fa1844..cc322dd 100755
--- a/scripts/maturin.sh
+++ b/scripts/maturin.sh
@@ -22,13 +22,27 @@ build_change build/Cargo.toml
 build_change Cargo.toml
 cd python
+
+trap ctrl_c INT
+
+function ctrl_c() {
+ cleanup
+ exit
+}
+
+function cleanup() {
+ # this is a bit hacky, assume we are in python/ dir
+ cd ..
+
+ mv python/.Cargo.toml.bak python/Cargo.toml
+ mv nlprule/.Cargo.toml.bak nlprule/Cargo.toml
+ mv build/.Cargo.toml.bak build/Cargo.toml
+ mv .Cargo.toml.bak Cargo.toml
+}
+
 maturin $@
 exit_code=$?
-cd ..
-mv python/.Cargo.toml.bak python/Cargo.toml
-mv nlprule/.Cargo.toml.bak nlprule/Cargo.toml
-mv build/.Cargo.toml.bak build/Cargo.toml
-mv .Cargo.toml.bak Cargo.toml
+cleanup
 exit $exit_code
\ No newline at end of file
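A consequence of `Rules` now owning its `Tokenizer` shows up in the pickling code in python/src/lib.rs above: `__getstate__`/`__setstate__` serialize a single `Rules` value instead of a `(Rules, Tokenizer)` tuple. A small sketch of what that means from Python, with illustrative paths and behaviour assumed from that implementation:

```python
# Sketch only: assumes the storage/ binaries produced by scripts/build_and_test.sh.
import pickle

from nlprule import Tokenizer, Rules

rules = Rules("storage/en_rules.bin", Tokenizer("storage/en_tokenizer.bin"))

# The pickled state is one bincode-serialized Rules (tokenizer included),
# so no separate Tokenizer object needs to travel with it.
restored = pickle.loads(pickle.dumps(rules))
assert restored.correct("I can due his homework") != "I can due his homework"
```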