@@ -41,8 +41,8 @@ impl<F: FType + fmt::Display> fmt::Display for MondrianTree<F> {
41
41
write ! ( f, "│ window_size: {}" , self . window_size) ?;
42
42
for ( i, node) in self . nodes . iter ( ) . enumerate ( ) {
43
43
writeln ! ( f) ?;
44
- // write!(f, "│ │ Node {}: left = {:?}, right = {:?}, parent = {:?}, tau = {}, min = {:?}, max = {:?}", i, node.left, node.right, node.parent, node.tau, node.min_list.to_vec(), node.max_list.to_vec())?;
45
44
write ! ( f, "│ │ Node {}: left={:?}, right={:?}, parent={:?}, tau={}, is_leaf={}, min={:?}, max={:?}" , i, node. left, node. right, node. parent, node. tau, node. is_leaf, node. min_list. to_vec( ) , node. max_list. to_vec( ) ) ?;
45
+ // write!(f, "│ │ Node {}: left={:?}, right={:?}, parent={:?}, is_leaf={}, min={:?}, max={:?}", i, node.left, node.right, node.parent, node.is_leaf, node.min_list.to_vec(), node.max_list.to_vec())?;
46
46
// write!(f, "│ │ Node {}: left={:?}, right={:?}, parent={:?}, tau={}, min={:?}, max={:?}", i, node.left, node.right, node.parent, node.tau, node.min_list.to_vec(), node.max_list.to_vec())?;
47
47
}
48
48
Ok ( ( ) )
@@ -62,30 +62,23 @@ impl<F: FType> MondrianTree<F> {
62
62
}
63
63
64
64
fn create_leaf ( & mut self , x : & Array1 < F > , label : & String , parent : Option < usize > ) -> usize {
65
- let min_list: ArrayBase < ndarray:: OwnedRepr < F > , Dim < [ usize ; 1 ] > > =
66
- Array1 :: zeros ( self . features . len ( ) ) ;
67
- let max_list = Array1 :: zeros ( self . features . len ( ) ) ;
68
-
69
65
let num_labels = self . labels . len ( ) ;
70
66
let feature_dim = self . features . len ( ) ;
71
- let labels = self . labels . clone ( ) ;
72
67
73
68
let mut node = Node :: < F > {
74
69
parent,
75
- tau : F :: from ( 1e9 ) . unwrap ( ) , // Very large value for tau
70
+ tau : F :: from ( 1e9 ) . unwrap ( ) , // Very large value
76
71
is_leaf : true ,
77
- min_list,
78
- max_list,
72
+ min_list : x . clone ( ) ,
73
+ max_list : x . clone ( ) ,
79
74
delta : 0 ,
80
75
xi : F :: zero ( ) ,
81
76
left : None ,
82
77
right : None ,
83
78
stats : Stats :: new ( num_labels, feature_dim) ,
84
79
} ;
85
80
86
- // TODO: check if this works:
87
- // labels: ["s002", "s003", "s004"]
88
- let label_idx = labels. iter ( ) . position ( |l| l == label) . unwrap ( ) ;
81
+ let label_idx = self . labels . clone ( ) . iter ( ) . position ( |l| l == label) . unwrap ( ) ;
89
82
node. update_leaf ( x, label_idx) ;
90
83
self . nodes . push ( node) ;
91
84
let node_idx = self . nodes . len ( ) - 1 ;
@@ -131,44 +124,52 @@ impl<F: FType> MondrianTree<F> {
131
124
}
132
125
133
126
fn extend_mondrian_block ( & mut self , node_idx : usize , x : & Array1 < F > , label : & String ) -> usize {
134
- println ! ( "PRE_MONDRIAN" ) ;
135
-
136
127
// Collect necessary values for computations
137
128
let parent_tau = self . get_parent_tau ( node_idx) ;
138
129
let tau = self . nodes [ node_idx] . tau ;
130
+ // TODO: 'node_min_list' and 'node_max_list' should be accessible without cloning
139
131
let node_min_list = self . nodes [ node_idx] . min_list . clone ( ) ;
140
132
let node_max_list = self . nodes [ node_idx] . max_list . clone ( ) ;
141
133
142
134
let e_min = ( & node_min_list - x) . mapv ( |v| F :: max ( v, F :: zero ( ) ) ) ;
143
135
let e_max = ( x - & node_max_list) . mapv ( |v| F :: max ( v, F :: zero ( ) ) ) ;
136
+ // e_sum: per-feature distance by which x lies outside the node's box [min_list, max_list]
144
137
let e_sum = & e_min + & e_max;
138
+ // 'rate' is lambda
145
139
let rate = e_sum. sum ( ) + F :: epsilon ( ) ;
146
140
let exp_dist = Exp :: new ( rate. to_f32 ( ) . unwrap ( ) ) . unwrap ( ) ;
147
- let E = F :: from_f32 ( exp_dist. sample ( & mut self . rng ) ) . unwrap ( ) ;
141
+ // 'exp_sample' is 'E' in nel215 code
142
+ let exp_sample = F :: from_f32 ( exp_dist. sample ( & mut self . rng ) ) . unwrap ( ) ;
143
+ // DEBUG: shadowing with Exp expected value
144
+ let exp_sample = F :: one ( ) / rate;
148
145
149
- if parent_tau + E < tau {
146
+ if parent_tau + exp_sample < tau {
150
147
let cumsum = e_sum
151
148
. iter ( )
152
149
. scan ( F :: zero ( ) , |acc, & x| {
153
150
* acc = * acc + x;
154
151
Some ( * acc)
155
152
} )
156
153
. collect :: < Array1 < F > > ( ) ;
157
- let e_sample =
158
- F :: from_f32 ( self . rng . gen :: < f32 > ( ) * e_sum. sum ( ) . to_f32 ( ) . unwrap ( ) ) . unwrap ( ) ;
154
+ // DEBUG: shadowing with expected value
155
+ let e_sample = F :: from_f32 ( self . rng . gen :: < f32 > ( ) ) . unwrap ( ) * e_sum. sum ( ) ;
156
+ let e_sample = F :: from_f32 ( 0.5 ) . unwrap ( ) * e_sum. sum ( ) ;
159
157
let delta = cumsum. iter ( ) . position ( |& val| val > e_sample) . unwrap_or ( 0 ) ;
160
- let xi =
161
- if x[ delta] > node_min_list[ delta] {
162
- F :: from_f32 ( self . rng . gen_range (
163
- node_min_list[ delta] . to_f32 ( ) . unwrap ( ) ..x[ delta] . to_f32 ( ) . unwrap ( ) ,
164
- ) )
165
- . unwrap ( )
166
- } else {
167
- F :: from_f32 ( self . rng . gen_range (
168
- x[ delta] . to_f32 ( ) . unwrap ( ) ..node_max_list[ delta] . to_f32 ( ) . unwrap ( ) ,
169
- ) )
170
- . unwrap ( )
171
- } ;
158
+
159
+ let ( lower_bound, upper_bound) = if x[ delta] > node_min_list[ delta] {
160
+ (
161
+ node_min_list[ delta] . to_f32 ( ) . unwrap ( ) ,
162
+ x[ delta] . to_f32 ( ) . unwrap ( ) ,
163
+ )
164
+ } else {
165
+ (
166
+ x[ delta] . to_f32 ( ) . unwrap ( ) ,
167
+ node_max_list[ delta] . to_f32 ( ) . unwrap ( ) ,
168
+ )
169
+ } ;
170
+ let xi = F :: from_f32 ( self . rng . gen_range ( lower_bound..upper_bound) ) . unwrap ( ) ;
171
+ // DEBUG: setting expected value
172
+ let xi = F :: from_f32 ( ( lower_bound + upper_bound) / 2.0 ) . unwrap ( ) ;
172
173
173
174
let mut min_list = node_min_list;
174
175
let mut max_list = node_max_list;
@@ -178,7 +179,7 @@ impl<F: FType> MondrianTree<F> {
178
179
// Create and push new parent node
179
180
let parent_node = Node {
180
181
parent : self . nodes [ node_idx] . parent ,
181
- tau : parent_tau + E ,
182
+ tau : parent_tau + exp_sample ,
182
183
is_leaf : false ,
183
184
min_list,
184
185
max_list,
@@ -188,13 +189,21 @@ impl<F: FType> MondrianTree<F> {
188
189
right : None ,
189
190
stats : Stats :: new ( self . labels . len ( ) , self . features . len ( ) ) ,
190
191
} ;
192
+ println ! (
193
+ "extend_mondrian_block() - mid if - grandpa: {:?}" ,
194
+ self . nodes[ node_idx] . parent
195
+ ) ;
191
196
192
197
self . nodes . push ( parent_node) ;
193
198
let parent_idx = self . nodes . len ( ) - 1 ;
194
199
let sibling_idx = self . create_leaf ( x, label, Some ( parent_idx) ) ;
195
200
196
201
// Set the children appropriately
197
202
if x[ delta] <= xi {
203
+ // Grandpa: self.nodes[node_idx].parent
204
+ // (new) Parent: parent_idx
205
+ // Child: node_idx
206
+ // (new) Sibling: sibling_idx
198
207
self . nodes [ parent_idx] . left = Some ( sibling_idx) ;
199
208
self . nodes [ parent_idx] . right = Some ( node_idx) ;
200
209
} else {
@@ -204,25 +213,47 @@ impl<F: FType> MondrianTree<F> {
204
213
205
214
self . nodes [ node_idx] . parent = Some ( parent_idx) ;
206
215
207
- self . update_internal ( parent_idx) ; // Moved the update logic to a new method
216
+ self . update_internal ( parent_idx) ;
217
+
218
+ println ! (
219
+ "extend_mondrian_block() - mid if - parent: {:?}, child: {:?}" ,
220
+ parent_idx, node_idx
221
+ ) ;
208
222
223
+ println ! ( "extend_modnrian_block() - post if" ) ;
209
224
return parent_idx;
210
225
} else {
211
226
let node = & mut self . nodes [ node_idx] ;
212
227
node. min_list . zip_mut_with ( x, |a, b| * a = F :: min ( * a, * b) ) ;
213
228
node. max_list . zip_mut_with ( x, |a, b| * a = F :: max ( * a, * b) ) ;
214
-
215
229
if !node. is_leaf {
216
- let child_idx = if x[ node. delta ] <= node. xi {
217
- node. left . unwrap ( )
230
+ println ! (
231
+ "extend_mondrian_block() - mid else - is_leaf - delta: {:?}, xi: {:?}" ,
232
+ node. delta, node. xi
233
+ ) ;
234
+ // TODO: understand how to make the following without making Rust angry with borrowing rules
235
+ // node.left = Some(self.extend_mondrian_block(node.left, x, label));
236
+ if x[ node. delta ] <= node. xi {
237
+ let node_left = node. left . unwrap ( ) ;
238
+ let node_left_new = Some ( self . extend_mondrian_block ( node_left, x, label) ) ;
239
+ let node = & mut self . nodes [ node_idx] ;
240
+ node. left = node_left_new;
218
241
} else {
219
- node. right . unwrap ( )
242
+ let node_right = node. right . unwrap ( ) ;
243
+ let node_right_new = Some ( self . extend_mondrian_block ( node_right, x, label) ) ;
244
+ let node = & mut self . nodes [ node_idx] ;
245
+ node. right = node_right_new;
220
246
} ;
221
- self . extend_mondrian_block ( child_idx, x, label) ;
222
- self . update_internal ( node_idx) ; // Moved the update logic to a new method
247
+ self . update_internal ( node_idx) ;
223
248
} else {
224
- node. update_leaf ( x, self . labels . iter ( ) . position ( |l| l == label) . unwrap ( ) ) ;
249
+ println ! (
250
+ "extend_mondrian_block() - mid else - is_leaf NOT - delta: {:?}, xi: {:?}" ,
251
+ node. delta, node. xi
252
+ ) ;
253
+ let label_idx = self . labels . iter ( ) . position ( |l| l == label) . unwrap ( ) ;
254
+ node. update_leaf ( x, label_idx) ;
225
255
}
256
+ println ! ( "extend_modnrian_block() - post else" ) ;
226
257
return node_idx;
227
258
}
228
259
}
@@ -240,12 +271,19 @@ impl<F: FType> MondrianTree<F> {
240
271
///
241
272
/// Function in River/LightRiver: "learn_one()"
242
273
pub fn partial_fit ( & mut self , x : & Array1 < F > , y : & String ) {
243
- println ! ( "partial_fit() - post root: {:?}" , self . root ) ;
274
+ // TODO: remove prints, roll back to previous version
244
275
self . root = match self . root {
245
- None => Some ( self . create_leaf ( x, y, None ) ) ,
246
- Some ( root_idx) => Some ( self . extend_mondrian_block ( root_idx, x, y) ) ,
276
+ None => {
277
+ let a = Some ( self . create_leaf ( x, y, None ) ) ;
278
+ println ! ( "create_leaf() - post {self}" ) ;
279
+ a
280
+ }
281
+ Some ( root_idx) => {
282
+ let a = Some ( self . extend_mondrian_block ( root_idx, x, y) ) ;
283
+ println ! ( "extend_modnrian_block() - post {self}" ) ;
284
+ a
285
+ }
247
286
} ;
248
- println ! ( "partial_fit() - post root: {:?}" , self . root) ;
249
287
}
250
288
251
289
fn fit ( & self ) {
@@ -318,7 +356,7 @@ impl<F: FType> MondrianTree<F> {
318
356
}
319
357
320
358
pub fn get_parent_tau ( & self , node_idx : usize ) -> F {
321
- // If node is root its time ( tau) is 0
359
+ // If node is root, tau is 0
322
360
match self . nodes [ node_idx] . parent {
323
361
Some ( parent_idx) => self . nodes [ parent_idx] . tau ,
324
362
None => F :: from_f32 ( 0.0 ) . unwrap ( ) ,
0 commit comments