@@ -13,6 +13,7 @@ use rand::prelude::*;
13
13
use rand_distr:: { Distribution , Exp } ;
14
14
use std:: cell:: RefCell ;
15
15
use std:: collections:: HashMap ;
16
+ use std:: collections:: HashSet ;
16
17
use std:: convert:: TryFrom ;
17
18
use std:: env:: consts;
18
19
use std:: fmt;
@@ -31,6 +32,7 @@ pub struct MondrianTree<F: FType> {
31
32
rng : ThreadRng ,
32
33
first_learn : bool ,
33
34
nodes : Vec < Node < F > > ,
35
+ root : Option < usize > ,
34
36
}
35
37
impl < F : FType + fmt:: Display > fmt:: Display for MondrianTree < F > {
36
38
fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
@@ -39,22 +41,23 @@ impl<F: FType + fmt::Display> fmt::Display for MondrianTree<F> {
39
41
write ! ( f, "│ window_size: {}" , self . window_size) ?;
40
42
for ( i, node) in self . nodes . iter ( ) . enumerate ( ) {
41
43
writeln ! ( f) ?;
42
- write ! ( f, "│ │ Node {}: left = {:?}, right = {:?}, parent = {:?}, tau = {}, is_leaf = {}, min = {:?}, max = {:?}" , i, node. left, node. right, node. parent, node. tau, node. is_leaf, node. min_list. to_vec( ) , node. max_list. to_vec( ) ) ?;
44
+ // write!(f, "│ │ Node {}: left = {:?}, right = {:?}, parent = {:?}, tau = {}, min = {:?}, max = {:?}", i, node.left, node.right, node.parent, node.tau, node.min_list.to_vec(), node.max_list.to_vec())?;
45
+ write ! ( f, "│ │ Node {}: left={:?}, right={:?}, parent={:?}, tau={}, is_leaf={}, min={:?}, max={:?}" , i, node. left, node. right, node. parent, node. tau, node. is_leaf, node. min_list. to_vec( ) , node. max_list. to_vec( ) ) ?;
46
+ // write!(f, "│ │ Node {}: left={:?}, right={:?}, parent={:?}, tau={}, min={:?}, max={:?}", i, node.left, node.right, node.parent, node.tau, node.min_list.to_vec(), node.max_list.to_vec())?;
43
47
}
44
48
Ok ( ( ) )
45
49
}
46
50
}
47
51
impl < F : FType > MondrianTree < F > {
48
52
pub fn new ( window_size : usize , features : & Vec < String > , labels : & Vec < String > ) -> Self {
49
- let mut rng = rand:: thread_rng ( ) ;
50
- let nodes = vec ! [ ] ;
51
53
MondrianTree :: < F > {
52
54
window_size,
53
55
features : features. clone ( ) ,
54
56
labels : labels. clone ( ) ,
55
- rng,
57
+ rng : rand :: thread_rng ( ) ,
56
58
first_learn : false ,
57
- nodes,
59
+ nodes : vec ! [ ] ,
60
+ root : None ,
58
61
}
59
62
}
60
63
@@ -79,6 +82,7 @@ impl<F: FType> MondrianTree<F> {
79
82
right : None ,
80
83
stats : Stats :: new ( num_labels, feature_dim) ,
81
84
} ;
85
+
82
86
// TODO: check if this works:
83
87
// labels: ["s002", "s003", "s004"]
84
88
let label_idx = labels. iter ( ) . position ( |l| l == label) . unwrap ( ) ;
@@ -92,10 +96,43 @@ impl<F: FType> MondrianTree<F> {
92
96
/// working only on one, so it's the same as "predict()".
93
97
pub fn predict_proba ( & self , x : & Array1 < F > ) -> Array1 < F > {
94
98
let root = 0 ;
99
+ self . test_tree ( ) ;
95
100
self . predict ( x, root, F :: one ( ) )
96
101
}
97
102
98
- fn extend_mondrian_block ( & mut self , node_idx : usize , x : & Array1 < F > , label : & String ) {
103
+ fn test_tree ( & self ) {
104
+ for node_idx in 0 ..self . nodes . len ( ) {
105
+ // TODO: check if self.root is None, if so tree should be empty
106
+ if node_idx == self . root . unwrap ( ) {
107
+ // Root node
108
+ assert ! ( self . nodes[ node_idx] . parent. is_none( ) , "Root has a parent." ) ;
109
+ } else {
110
+ // Non-root node
111
+ assert ! (
112
+ !self . nodes[ node_idx] . parent. is_none( ) ,
113
+ "Non-root node has no parent"
114
+ )
115
+ }
116
+ }
117
+
118
+ let children_l: Vec < usize > = self . nodes . iter ( ) . filter_map ( |node| node. left ) . collect ( ) ;
119
+ let children_r: Vec < usize > = self . nodes . iter ( ) . filter_map ( |node| node. right ) . collect ( ) ;
120
+ let children = [ children_l. clone ( ) , children_r. clone ( ) ] . concat ( ) ;
121
+ let mut seen = HashSet :: new ( ) ;
122
+ let has_duplicates = children. iter ( ) . any ( |item| !seen. insert ( item) ) ;
123
+ assert ! (
124
+ !has_duplicates,
125
+ "Multiple nodes share 1 child. Children left: {:?}, Children right: {:?}" ,
126
+ children_l, children_r
127
+ ) ;
128
+
129
+ // TODO: replace this test with a "Tree integrity" by starting from the root node, recursively
130
+ // go to the child, check if the parent is correct.
131
+ }
132
+
133
+ fn extend_mondrian_block ( & mut self , node_idx : usize , x : & Array1 < F > , label : & String ) -> usize {
134
+ println ! ( "PRE_MONDRIAN" ) ;
135
+
99
136
// Collect necessary values for computations
100
137
let parent_tau = self . get_parent_tau ( node_idx) ;
101
138
let tau = self . nodes [ node_idx] . tau ;
@@ -168,6 +205,8 @@ impl<F: FType> MondrianTree<F> {
168
205
self . nodes [ node_idx] . parent = Some ( parent_idx) ;
169
206
170
207
self . update_internal ( parent_idx) ; // Moved the update logic to a new method
208
+
209
+ return parent_idx;
171
210
} else {
172
211
let node = & mut self . nodes [ node_idx] ;
173
212
node. min_list . zip_mut_with ( x, |a, b| * a = F :: min ( * a, * b) ) ;
@@ -184,6 +223,7 @@ impl<F: FType> MondrianTree<F> {
184
223
} else {
185
224
node. update_leaf ( x, self . labels . iter ( ) . position ( |l| l == label) . unwrap ( ) ) ;
186
225
}
226
+ return node_idx;
187
227
}
188
228
}
189
229
@@ -200,11 +240,12 @@ impl<F: FType> MondrianTree<F> {
200
240
///
201
241
/// Function in River/LightRiver: "learn_one()"
202
242
pub fn partial_fit ( & mut self , x : & Array1 < F > , y : & String ) {
203
- if self . nodes . len ( ) == 0 {
204
- self . create_leaf ( x, y, None ) ;
205
- } else {
206
- self . extend_mondrian_block ( 0 , x, y) ;
207
- }
243
+ println ! ( "partial_fit() - post root: {:?}" , self . root) ;
244
+ self . root = match self . root {
245
+ None => Some ( self . create_leaf ( x, y, None ) ) ,
246
+ Some ( root_idx) => Some ( self . extend_mondrian_block ( root_idx, x, y) ) ,
247
+ } ;
248
+ println ! ( "partial_fit() - post root: {:?}" , self . root) ;
208
249
}
209
250
210
251
fn fit ( & self ) {
0 commit comments