Skip to content

Commit f79864d

Browse files
Fix pointer of grandpa on extend_mondrian_block
1 parent 667d35e commit f79864d

File tree

5 files changed

+91
-56
lines changed

5 files changed

+91
-56
lines changed

examples/classification/keystroke.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ fn get_features(transactions: IterCsv<f32, File>) -> Vec<String> {
2020
// TODO: pass transaction file by reference, in main use only one "Keystroke::load_data().unwrap()"
2121
let sample = transactions.into_iter().next();
2222
let observation = sample.unwrap().unwrap().get_observation();
23-
observation.iter().map(|(k, _)| k.clone()).collect()
23+
let mut out: Vec<String> = observation.iter().map(|(k, _)| k.clone()).collect();
24+
out.sort();
25+
out
2426
}
2527

2628
fn get_labels(transactions: IterCsv<f32, File>) -> Vec<String> {

examples/classification/synthetic.rs

+5-9
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ use std::time::Instant;
1919
fn get_features(transactions: IterCsv<f32, File>) -> Vec<String> {
2020
let sample = transactions.into_iter().next();
2121
let observation = sample.unwrap().unwrap().get_observation();
22-
observation.iter().map(|(k, _)| k.clone()).collect()
22+
let mut out: Vec<String> = observation.iter().map(|(k, _)| k.clone()).collect();
23+
out.sort();
24+
out
2325
}
2426

2527
fn get_labels(transactions: IterCsv<f32, File>) -> Vec<String> {
@@ -42,14 +44,10 @@ fn main() {
4244

4345
let transactions_f = Synthetic::load_data().unwrap();
4446
let features = get_features(transactions_f);
45-
// DEBUG: remove it
46-
// let features = features[0..2].to_vec();
4747

4848
let transactions_c = Synthetic::load_data().unwrap();
4949
let labels = get_labels(transactions_c);
50-
// DEBUG: remove it
51-
// let labels = labels[0..3].to_vec();
52-
println!("labels: {labels:?}");
50+
println!("labels: {labels:?}, features: {features:?}");
5351
let mut mf: MondrianForest<f32> = MondrianForest::new(window_size, n_trees, &features, &labels);
5452

5553
let transactions = Synthetic::load_data().unwrap();
@@ -65,16 +63,14 @@ fn main() {
6563
_ => unimplemented!(),
6664
};
6765
let x_ord = Array1::<f32>::from_vec(features.iter().map(|k| x[k]).collect());
68-
// DEBUG: remove it
69-
// let x_ord = x_ord.slice(s![0..2]).to_owned();
7066

7167
println!("=M=1 partial_fit {x_ord}");
7268
mf.partial_fit(&x_ord, &y);
7369

7470
println!("=M=2 predict_proba");
7571
let score = mf.predict_proba(&x_ord);
7672

77-
println!("=M=3 score: {:?}", score);
73+
println!("=M=3 score: {:?}", score.to_vec());
7874
println!("");
7975
}
8076

src/classification/mondrian_forest.rs

-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ impl<F: FType> MondrianForest<F> {
4545
pub fn partial_fit(&mut self, x: &Array1<F>, y: &String) {
4646
for tree in &mut self.trees {
4747
tree.partial_fit(x, y);
48-
println!("treeee {}", tree);
4948
}
5049
}
5150

src/classification/mondrian_node.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ pub struct Stats<F> {
7171
}
7272
impl<F: FType + fmt::Display> fmt::Display for Stats<F> {
7373
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74-
writeln!(f, "┌ Stats")?;
74+
writeln!(f, "\n┌ Stats")?;
7575
// sums
7676
write!(f, "│ sums: [")?;
7777
for row in self.sums.outer_iter() {

src/classification/mondrian_tree.rs

+82-44
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ impl<F: FType + fmt::Display> fmt::Display for MondrianTree<F> {
4141
write!(f, "│ window_size: {}", self.window_size)?;
4242
for (i, node) in self.nodes.iter().enumerate() {
4343
writeln!(f)?;
44-
// write!(f, "│ │ Node {}: left = {:?}, right = {:?}, parent = {:?}, tau = {}, min = {:?}, max = {:?}", i, node.left, node.right, node.parent, node.tau, node.min_list.to_vec(), node.max_list.to_vec())?;
4544
write!(f, "│ │ Node {}: left={:?}, right={:?}, parent={:?}, tau={}, is_leaf={}, min={:?}, max={:?}", i, node.left, node.right, node.parent, node.tau, node.is_leaf, node.min_list.to_vec(), node.max_list.to_vec())?;
45+
// write!(f, "│ │ Node {}: left={:?}, right={:?}, parent={:?}, is_leaf={}, min={:?}, max={:?}", i, node.left, node.right, node.parent, node.is_leaf, node.min_list.to_vec(), node.max_list.to_vec())?;
4646
// write!(f, "│ │ Node {}: left={:?}, right={:?}, parent={:?}, tau={}, min={:?}, max={:?}", i, node.left, node.right, node.parent, node.tau, node.min_list.to_vec(), node.max_list.to_vec())?;
4747
}
4848
Ok(())
@@ -62,30 +62,23 @@ impl<F: FType> MondrianTree<F> {
6262
}
6363

6464
fn create_leaf(&mut self, x: &Array1<F>, label: &String, parent: Option<usize>) -> usize {
65-
let min_list: ArrayBase<ndarray::OwnedRepr<F>, Dim<[usize; 1]>> =
66-
Array1::zeros(self.features.len());
67-
let max_list = Array1::zeros(self.features.len());
68-
6965
let num_labels = self.labels.len();
7066
let feature_dim = self.features.len();
71-
let labels = self.labels.clone();
7267

7368
let mut node = Node::<F> {
7469
parent,
75-
tau: F::from(1e9).unwrap(), // Very large value for tau
70+
tau: F::from(1e9).unwrap(), // Very large value
7671
is_leaf: true,
77-
min_list,
78-
max_list,
72+
min_list: x.clone(),
73+
max_list: x.clone(),
7974
delta: 0,
8075
xi: F::zero(),
8176
left: None,
8277
right: None,
8378
stats: Stats::new(num_labels, feature_dim),
8479
};
8580

86-
// TODO: check if this works:
87-
// labels: ["s002", "s003", "s004"]
88-
let label_idx = labels.iter().position(|l| l == label).unwrap();
81+
let label_idx = self.labels.clone().iter().position(|l| l == label).unwrap();
8982
node.update_leaf(x, label_idx);
9083
self.nodes.push(node);
9184
let node_idx = self.nodes.len() - 1;
@@ -131,44 +124,52 @@ impl<F: FType> MondrianTree<F> {
131124
}
132125

133126
fn extend_mondrian_block(&mut self, node_idx: usize, x: &Array1<F>, label: &String) -> usize {
134-
println!("PRE_MONDRIAN");
135-
136127
// Collect necessary values for computations
137128
let parent_tau = self.get_parent_tau(node_idx);
138129
let tau = self.nodes[node_idx].tau;
130+
// TODO: 'node_min_list' and 'node_max_list' be accessible without cloning
139131
let node_min_list = self.nodes[node_idx].min_list.clone();
140132
let node_max_list = self.nodes[node_idx].max_list.clone();
141133

142134
let e_min = (&node_min_list - x).mapv(|v| F::max(v, F::zero()));
143135
let e_max = (x - &node_max_list).mapv(|v| F::max(v, F::zero()));
136+
// e_sum: size of the box [x_size, y_size]
144137
let e_sum = &e_min + &e_max;
138+
// 'rate' is lambda
145139
let rate = e_sum.sum() + F::epsilon();
146140
let exp_dist = Exp::new(rate.to_f32().unwrap()).unwrap();
147-
let E = F::from_f32(exp_dist.sample(&mut self.rng)).unwrap();
141+
// 'exp_sample' is 'E' in nel215 code
142+
let exp_sample = F::from_f32(exp_dist.sample(&mut self.rng)).unwrap();
143+
// DEBUG: shadowing with Exp expected value
144+
let exp_sample = F::one() / rate;
148145

149-
if parent_tau + E < tau {
146+
if parent_tau + exp_sample < tau {
150147
let cumsum = e_sum
151148
.iter()
152149
.scan(F::zero(), |acc, &x| {
153150
*acc = *acc + x;
154151
Some(*acc)
155152
})
156153
.collect::<Array1<F>>();
157-
let e_sample =
158-
F::from_f32(self.rng.gen::<f32>() * e_sum.sum().to_f32().unwrap()).unwrap();
154+
// DEBUG: shadowing with expected value
155+
let e_sample = F::from_f32(self.rng.gen::<f32>()).unwrap() * e_sum.sum();
156+
let e_sample = F::from_f32(0.5).unwrap() * e_sum.sum();
159157
let delta = cumsum.iter().position(|&val| val > e_sample).unwrap_or(0);
160-
let xi =
161-
if x[delta] > node_min_list[delta] {
162-
F::from_f32(self.rng.gen_range(
163-
node_min_list[delta].to_f32().unwrap()..x[delta].to_f32().unwrap(),
164-
))
165-
.unwrap()
166-
} else {
167-
F::from_f32(self.rng.gen_range(
168-
x[delta].to_f32().unwrap()..node_max_list[delta].to_f32().unwrap(),
169-
))
170-
.unwrap()
171-
};
158+
159+
let (lower_bound, upper_bound) = if x[delta] > node_min_list[delta] {
160+
(
161+
node_min_list[delta].to_f32().unwrap(),
162+
x[delta].to_f32().unwrap(),
163+
)
164+
} else {
165+
(
166+
x[delta].to_f32().unwrap(),
167+
node_max_list[delta].to_f32().unwrap(),
168+
)
169+
};
170+
let xi = F::from_f32(self.rng.gen_range(lower_bound..upper_bound)).unwrap();
171+
// DEBUG: setting expected value
172+
let xi = F::from_f32((lower_bound + upper_bound) / 2.0).unwrap();
172173

173174
let mut min_list = node_min_list;
174175
let mut max_list = node_max_list;
@@ -178,7 +179,7 @@ impl<F: FType> MondrianTree<F> {
178179
// Create and push new parent node
179180
let parent_node = Node {
180181
parent: self.nodes[node_idx].parent,
181-
tau: parent_tau + E,
182+
tau: parent_tau + exp_sample,
182183
is_leaf: false,
183184
min_list,
184185
max_list,
@@ -188,13 +189,21 @@ impl<F: FType> MondrianTree<F> {
188189
right: None,
189190
stats: Stats::new(self.labels.len(), self.features.len()),
190191
};
192+
println!(
193+
"extend_mondrian_block() - mid if - grandpa: {:?}",
194+
self.nodes[node_idx].parent
195+
);
191196

192197
self.nodes.push(parent_node);
193198
let parent_idx = self.nodes.len() - 1;
194199
let sibling_idx = self.create_leaf(x, label, Some(parent_idx));
195200

196201
// Set the children appropriately
197202
if x[delta] <= xi {
203+
// Grandpa: self.nodes[node_idx].parent
204+
// (new) Parent: parent_idx
205+
// Child: node_idx
206+
// (new) Sibling: sibling_idx
198207
self.nodes[parent_idx].left = Some(sibling_idx);
199208
self.nodes[parent_idx].right = Some(node_idx);
200209
} else {
@@ -204,25 +213,47 @@ impl<F: FType> MondrianTree<F> {
204213

205214
self.nodes[node_idx].parent = Some(parent_idx);
206215

207-
self.update_internal(parent_idx); // Moved the update logic to a new method
216+
self.update_internal(parent_idx);
217+
218+
println!(
219+
"extend_mondrian_block() - mid if - parent: {:?}, child: {:?}",
220+
parent_idx, node_idx
221+
);
208222

223+
println!("extend_modnrian_block() - post if");
209224
return parent_idx;
210225
} else {
211226
let node = &mut self.nodes[node_idx];
212227
node.min_list.zip_mut_with(x, |a, b| *a = F::min(*a, *b));
213228
node.max_list.zip_mut_with(x, |a, b| *a = F::max(*a, *b));
214-
215229
if !node.is_leaf {
216-
let child_idx = if x[node.delta] <= node.xi {
217-
node.left.unwrap()
230+
println!(
231+
"extend_mondrian_block() - mid else - is_leaf - delta: {:?}, xi: {:?}",
232+
node.delta, node.xi
233+
);
234+
// TODO: understand how to make the following without making Rust angry with borrowing rules
235+
// node.left = Some(self.extend_mondrian_block(node.left, x, label));
236+
if x[node.delta] <= node.xi {
237+
let node_left = node.left.unwrap();
238+
let node_left_new = Some(self.extend_mondrian_block(node_left, x, label));
239+
let node = &mut self.nodes[node_idx];
240+
node.left = node_left_new;
218241
} else {
219-
node.right.unwrap()
242+
let node_right = node.right.unwrap();
243+
let node_right_new = Some(self.extend_mondrian_block(node_right, x, label));
244+
let node = &mut self.nodes[node_idx];
245+
node.right = node_right_new;
220246
};
221-
self.extend_mondrian_block(child_idx, x, label);
222-
self.update_internal(node_idx); // Moved the update logic to a new method
247+
self.update_internal(node_idx);
223248
} else {
224-
node.update_leaf(x, self.labels.iter().position(|l| l == label).unwrap());
249+
println!(
250+
"extend_mondrian_block() - mid else - is_leaf NOT - delta: {:?}, xi: {:?}",
251+
node.delta, node.xi
252+
);
253+
let label_idx = self.labels.iter().position(|l| l == label).unwrap();
254+
node.update_leaf(x, label_idx);
225255
}
256+
println!("extend_modnrian_block() - post else");
226257
return node_idx;
227258
}
228259
}
@@ -240,12 +271,19 @@ impl<F: FType> MondrianTree<F> {
240271
///
241272
/// Function in River/LightRiver: "learn_one()"
242273
pub fn partial_fit(&mut self, x: &Array1<F>, y: &String) {
243-
println!("partial_fit() - post root: {:?}", self.root);
274+
// TODO: remove prints, roll back to previous version
244275
self.root = match self.root {
245-
None => Some(self.create_leaf(x, y, None)),
246-
Some(root_idx) => Some(self.extend_mondrian_block(root_idx, x, y)),
276+
None => {
277+
let a = Some(self.create_leaf(x, y, None));
278+
println!("create_leaf() - post {self}");
279+
a
280+
}
281+
Some(root_idx) => {
282+
let a = Some(self.extend_mondrian_block(root_idx, x, y));
283+
println!("extend_modnrian_block() - post {self}");
284+
a
285+
}
247286
};
248-
println!("partial_fit() - post root: {:?}", self.root);
249287
}
250288

251289
fn fit(&self) {
@@ -318,7 +356,7 @@ impl<F: FType> MondrianTree<F> {
318356
}
319357

320358
pub fn get_parent_tau(&self, node_idx: usize) -> F {
321-
// If node is root its time (tau) is 0
359+
// If node is root, tau is 0
322360
match self.nodes[node_idx].parent {
323361
Some(parent_idx) => self.nodes[parent_idx].tau,
324362
None => F::from_f32(0.0).unwrap(),

0 commit comments

Comments
 (0)