Skip to content

Commit 667d35e

Browse files
Add synthetic dataset and tree integrity tests
1 parent de5d67a commit 667d35e

File tree

7 files changed

+169
-18
lines changed

7 files changed

+169
-18
lines changed

Cargo.toml

+4
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ path = "examples/anomaly_detection/credit_card.rs"
3535
name = "keystroke"
3636
path = "examples/classification/keystroke.rs"
3737

38+
[[example]]
39+
name = "synthetic"
40+
path = "examples/classification/synthetic.rs"
41+
3842
[[bench]]
3943
name = "hst"
4044
harness = false

examples/classification/keystroke.rs

+7-6
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,14 @@ fn main() {
4444
let n_trees: usize = 1;
4545

4646
let transactions_f = Keystroke::load_data().unwrap();
47-
let mut features = get_features(transactions_f);
47+
let features = get_features(transactions_f);
4848
// DEBUG: remove it
49-
features = features[0..2].to_vec();
49+
// let features = features[0..2].to_vec();
5050

5151
let transactions_c = Keystroke::load_data().unwrap();
52-
let mut labels = get_labels(transactions_c);
53-
labels = labels[0..3].to_vec();
52+
let labels = get_labels(transactions_c);
53+
// DEBUG: remove it
54+
// let labels = labels[0..3].to_vec();
5455
println!("labels: {labels:?}");
5556
let mut mf: MondrianForest<f32> = MondrianForest::new(window_size, n_trees, &features, &labels);
5657

@@ -66,9 +67,9 @@ fn main() {
6667
ClassifierTarget::String(y) => y,
6768
_ => unimplemented!(),
6869
};
69-
let mut x_ord = Array1::<f32>::from_vec(features.iter().map(|k| x[k]).collect());
70+
let x_ord = Array1::<f32>::from_vec(features.iter().map(|k| x[k]).collect());
7071
// DEBUG: remove it
71-
x_ord = x_ord.slice(s![0..2]).to_owned();
72+
// let x_ord = x_ord.slice(s![0..2]).to_owned();
7273

7374
println!("=M=1 partial_fit");
7475
mf.partial_fit(&x_ord, &y);

examples/classification/synthetic.rs

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
use light_river::classification::alias::FType;
2+
use light_river::classification::mondrian_forest::MondrianForest;
3+
use light_river::classification::mondrian_tree::MondrianTree;
4+
use light_river::common::ClassifierOutput;
5+
use light_river::common::ClassifierTarget;
6+
use light_river::datasets::synthetic::Synthetic;
7+
use light_river::metrics::rocauc::ROCAUC;
8+
use light_river::metrics::traits::ClassificationMetric;
9+
use light_river::stream::data_stream::DataStream;
10+
use light_river::stream::iter_csv::IterCsv;
11+
use ndarray::{s, Array1};
12+
use std::borrow::Borrow;
13+
use std::fs::File;
14+
use std::time::Instant;
15+
16+
/// Get list of features of the dataset.
17+
///
18+
/// e.g. features: ["H.e", "UD.t.i", "H.i", ...]
19+
fn get_features(transactions: IterCsv<f32, File>) -> Vec<String> {
20+
let sample = transactions.into_iter().next();
21+
let observation = sample.unwrap().unwrap().get_observation();
22+
observation.iter().map(|(k, _)| k.clone()).collect()
23+
}
24+
25+
fn get_labels(transactions: IterCsv<f32, File>) -> Vec<String> {
26+
let mut labels = vec![];
27+
for t in transactions {
28+
let data = t.unwrap();
29+
// TODO: use instead 'to_classifier_target' and a vector of 'ClassifierTarget'
30+
let target = data.get_y().unwrap()["label"].to_string();
31+
if !labels.contains(&target) {
32+
labels.push(target);
33+
}
34+
}
35+
labels
36+
}
37+
38+
fn main() {
39+
let now = Instant::now();
40+
let window_size: usize = 1000;
41+
let n_trees: usize = 1;
42+
43+
let transactions_f = Synthetic::load_data().unwrap();
44+
let features = get_features(transactions_f);
45+
// DEBUG: remove it
46+
// let features = features[0..2].to_vec();
47+
48+
let transactions_c = Synthetic::load_data().unwrap();
49+
let labels = get_labels(transactions_c);
50+
// DEBUG: remove it
51+
// let labels = labels[0..3].to_vec();
52+
println!("labels: {labels:?}");
53+
let mut mf: MondrianForest<f32> = MondrianForest::new(window_size, n_trees, &features, &labels);
54+
55+
let transactions = Synthetic::load_data().unwrap();
56+
for transaction in transactions {
57+
let data = transaction.unwrap();
58+
59+
let x = data.get_observation();
60+
let y = data.to_classifier_target("label").unwrap();
61+
// TODO: generalize to non-classification only by implementing 'ClassifierTarget'
62+
// instead of taking directly the string.
63+
let y = match y {
64+
ClassifierTarget::String(y) => y,
65+
_ => unimplemented!(),
66+
};
67+
let x_ord = Array1::<f32>::from_vec(features.iter().map(|k| x[k]).collect());
68+
// DEBUG: remove it
69+
// let x_ord = x_ord.slice(s![0..2]).to_owned();
70+
71+
println!("=M=1 partial_fit {x_ord}");
72+
mf.partial_fit(&x_ord, &y);
73+
74+
println!("=M=2 predict_proba");
75+
let score = mf.predict_proba(&x_ord);
76+
77+
println!("=M=3 score: {:?}", score);
78+
println!("");
79+
}
80+
81+
let elapsed_time = now.elapsed();
82+
println!("Took {}ms", elapsed_time.as_millis());
83+
// println!("ROCAUC: {:.2}%", roc_auc.get() * (100.0 as f32));
84+
}

src/classification/mondrian_forest.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ impl<F: FType> MondrianForest<F> {
4343
///
4444
/// Function in River/LightRiver: "learn_one()"
4545
pub fn partial_fit(&mut self, x: &Array1<F>, y: &String) {
46-
println!("partial_fit() - x: {:?}, y: {y:?}", x.to_vec());
4746
for tree in &mut self.trees {
4847
tree.partial_fit(x, y);
48+
println!("treeee {}", tree);
4949
}
5050
}
5151

src/classification/mondrian_tree.rs

+52-11
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use rand::prelude::*;
1313
use rand_distr::{Distribution, Exp};
1414
use std::cell::RefCell;
1515
use std::collections::HashMap;
16+
use std::collections::HashSet;
1617
use std::convert::TryFrom;
1718
use std::env::consts;
1819
use std::fmt;
@@ -31,6 +32,7 @@ pub struct MondrianTree<F: FType> {
3132
rng: ThreadRng,
3233
first_learn: bool,
3334
nodes: Vec<Node<F>>,
35+
root: Option<usize>,
3436
}
3537
impl<F: FType + fmt::Display> fmt::Display for MondrianTree<F> {
3638
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
@@ -39,22 +41,23 @@ impl<F: FType + fmt::Display> fmt::Display for MondrianTree<F> {
3941
write!(f, "│ window_size: {}", self.window_size)?;
4042
for (i, node) in self.nodes.iter().enumerate() {
4143
writeln!(f)?;
42-
write!(f, "│ │ Node {}: left = {:?}, right = {:?}, parent = {:?}, tau = {}, is_leaf = {}, min = {:?}, max = {:?}", i, node.left, node.right, node.parent, node.tau, node.is_leaf, node.min_list.to_vec(), node.max_list.to_vec())?;
44+
// write!(f, "│ │ Node {}: left = {:?}, right = {:?}, parent = {:?}, tau = {}, min = {:?}, max = {:?}", i, node.left, node.right, node.parent, node.tau, node.min_list.to_vec(), node.max_list.to_vec())?;
45+
write!(f, "│ │ Node {}: left={:?}, right={:?}, parent={:?}, tau={}, is_leaf={}, min={:?}, max={:?}", i, node.left, node.right, node.parent, node.tau, node.is_leaf, node.min_list.to_vec(), node.max_list.to_vec())?;
46+
// write!(f, "│ │ Node {}: left={:?}, right={:?}, parent={:?}, tau={}, min={:?}, max={:?}", i, node.left, node.right, node.parent, node.tau, node.min_list.to_vec(), node.max_list.to_vec())?;
4347
}
4448
Ok(())
4549
}
4650
}
4751
impl<F: FType> MondrianTree<F> {
4852
pub fn new(window_size: usize, features: &Vec<String>, labels: &Vec<String>) -> Self {
49-
let mut rng = rand::thread_rng();
50-
let nodes = vec![];
5153
MondrianTree::<F> {
5254
window_size,
5355
features: features.clone(),
5456
labels: labels.clone(),
55-
rng,
57+
rng: rand::thread_rng(),
5658
first_learn: false,
57-
nodes,
59+
nodes: vec![],
60+
root: None,
5861
}
5962
}
6063

@@ -79,6 +82,7 @@ impl<F: FType> MondrianTree<F> {
7982
right: None,
8083
stats: Stats::new(num_labels, feature_dim),
8184
};
85+
8286
// TODO: check if this works:
8387
// labels: ["s002", "s003", "s004"]
8488
let label_idx = labels.iter().position(|l| l == label).unwrap();
@@ -92,10 +96,43 @@ impl<F: FType> MondrianTree<F> {
9296
/// working only on one, so it's the same as "predict()".
9397
pub fn predict_proba(&self, x: &Array1<F>) -> Array1<F> {
9498
let root = 0;
99+
self.test_tree();
95100
self.predict(x, root, F::one())
96101
}
97102

98-
fn extend_mondrian_block(&mut self, node_idx: usize, x: &Array1<F>, label: &String) {
103+
fn test_tree(&self) {
104+
for node_idx in 0..self.nodes.len() {
105+
// TODO: check if self.root is None, if so tree should be empty
106+
if node_idx == self.root.unwrap() {
107+
// Root node
108+
assert!(self.nodes[node_idx].parent.is_none(), "Root has a parent.");
109+
} else {
110+
// Non-root node
111+
assert!(
112+
!self.nodes[node_idx].parent.is_none(),
113+
"Non-root node has no parent"
114+
)
115+
}
116+
}
117+
118+
let children_l: Vec<usize> = self.nodes.iter().filter_map(|node| node.left).collect();
119+
let children_r: Vec<usize> = self.nodes.iter().filter_map(|node| node.right).collect();
120+
let children = [children_l.clone(), children_r.clone()].concat();
121+
let mut seen = HashSet::new();
122+
let has_duplicates = children.iter().any(|item| !seen.insert(item));
123+
assert!(
124+
!has_duplicates,
125+
"Multiple nodes share 1 child. Children left: {:?}, Children right: {:?}",
126+
children_l, children_r
127+
);
128+
129+
// TODO: replace this test with a "Tree integrity" by starting from the root node, recursively
130+
// go to the child, check if the parent is correct.
131+
}
132+
133+
fn extend_mondrian_block(&mut self, node_idx: usize, x: &Array1<F>, label: &String) -> usize {
134+
println!("PRE_MONDRIAN");
135+
99136
// Collect necessary values for computations
100137
let parent_tau = self.get_parent_tau(node_idx);
101138
let tau = self.nodes[node_idx].tau;
@@ -168,6 +205,8 @@ impl<F: FType> MondrianTree<F> {
168205
self.nodes[node_idx].parent = Some(parent_idx);
169206

170207
self.update_internal(parent_idx); // Moved the update logic to a new method
208+
209+
return parent_idx;
171210
} else {
172211
let node = &mut self.nodes[node_idx];
173212
node.min_list.zip_mut_with(x, |a, b| *a = F::min(*a, *b));
@@ -184,6 +223,7 @@ impl<F: FType> MondrianTree<F> {
184223
} else {
185224
node.update_leaf(x, self.labels.iter().position(|l| l == label).unwrap());
186225
}
226+
return node_idx;
187227
}
188228
}
189229

@@ -200,11 +240,12 @@ impl<F: FType> MondrianTree<F> {
200240
///
201241
/// Function in River/LightRiver: "learn_one()"
202242
pub fn partial_fit(&mut self, x: &Array1<F>, y: &String) {
203-
if self.nodes.len() == 0 {
204-
self.create_leaf(x, y, None);
205-
} else {
206-
self.extend_mondrian_block(0, x, y);
207-
}
243+
println!("partial_fit() - post root: {:?}", self.root);
244+
self.root = match self.root {
245+
None => Some(self.create_leaf(x, y, None)),
246+
Some(root_idx) => Some(self.extend_mondrian_block(root_idx, x, y)),
247+
};
248+
println!("partial_fit() - post root: {:?}", self.root);
208249
}
209250

210251
fn fit(&self) {

src/datasets/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
pub mod credit_card;
22
pub mod keystroke;
3+
pub mod synthetic;
34
pub mod utils;

src/datasets/synthetic.rs

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
use crate::datasets::utils;
2+
use crate::stream::data_stream::Target;
3+
use crate::stream::iter_csv::IterCsv;
4+
use std::{fs::File, path::Path};
5+
6+
/// ChatGPT Generated synthetic dataset.
7+
///
8+
/// Add 'synthetic.csv' to project root directory.
9+
pub struct Synthetic;
10+
impl Synthetic {
11+
pub fn load_data() -> Result<IterCsv<f32, File>, Box<dyn std::error::Error>> {
12+
let file_name = "syntetic_dataset_int.csv";
13+
let file = File::open(file_name)?;
14+
let y_cols = Some(Target::Name("label".to_string()));
15+
match IterCsv::<f32, File>::new(file, y_cols) {
16+
Ok(x) => Ok(x),
17+
Err(e) => Err(Box::new(e)),
18+
}
19+
}
20+
}

0 commit comments

Comments
 (0)