2
2
//! The purpose of this pass is to inline the instructions of each function call
3
3
//! within the function caller. If all function calls are known, there will only
4
4
//! be a single function remaining when the pass finishes.
5
- use std:: collections:: { BTreeMap , BTreeSet , HashSet , VecDeque } ;
5
+ use std:: collections:: { BTreeSet , HashSet , VecDeque } ;
6
6
7
7
use acvm:: acir:: AcirField ;
8
8
use im:: HashMap ;
@@ -21,6 +21,10 @@ use crate::ssa::{
21
21
ssa_gen:: Ssa ,
22
22
} ;
23
23
24
+ pub ( super ) mod inline_info;
25
+
26
+ pub ( super ) use inline_info:: { compute_inline_infos, InlineInfo , InlineInfos } ;
27
+
24
28
/// Upper bound on how many recursive call frames may be live at once.
/// The value itself is arbitrary.
const RECURSION_LIMIT: u32 = 1_000;
@@ -206,366 +210,7 @@ fn called_functions(func: &Function) -> BTreeSet<FunctionId> {
206
210
called_functions_vec ( func) . into_iter ( ) . collect ( )
207
211
}
208
212
209
- /// Information about a function to aid the decision about whether to inline it or not.
210
- /// The final decision depends on what we're inlining it into.
211
- #[ derive( Default , Debug ) ]
212
- pub ( super ) struct InlineInfo {
213
- is_brillig_entry_point : bool ,
214
- is_acir_entry_point : bool ,
215
- is_recursive : bool ,
216
- pub ( super ) should_inline : bool ,
217
- weight : i64 ,
218
- cost : i64 ,
219
- }
220
-
221
- impl InlineInfo {
222
- /// Functions which are to be retained, not inlined.
223
- pub ( super ) fn is_inline_target ( & self ) -> bool {
224
- self . is_brillig_entry_point
225
- || self . is_acir_entry_point
226
- || self . is_recursive
227
- || !self . should_inline
228
- }
229
-
230
- pub ( super ) fn should_inline ( inline_infos : & InlineInfos , called_func_id : FunctionId ) -> bool {
231
- inline_infos. get ( & called_func_id) . map ( |info| info. should_inline ) . unwrap_or_default ( )
232
- }
233
- }
234
-
235
- type InlineInfos = BTreeMap < FunctionId , InlineInfo > ;
236
-
237
- /// The functions we should inline into (and that should be left in the final program) are:
238
- /// - main
239
- /// - Any Brillig function called from Acir
240
- /// - Some Brillig functions depending on aggressiveness and some metrics
241
- /// - Any Acir functions with a [fold inline type][InlineType::Fold],
242
- ///
243
- /// The returned `InlineInfos` won't have every function in it, only the ones which the algorithm visited.
244
- pub ( super ) fn compute_inline_infos (
245
- ssa : & Ssa ,
246
- inline_no_predicates_functions : bool ,
247
- aggressiveness : i64 ,
248
- ) -> InlineInfos {
249
- let mut inline_infos = InlineInfos :: default ( ) ;
250
-
251
- inline_infos. insert (
252
- ssa. main_id ,
253
- InlineInfo {
254
- is_acir_entry_point : ssa. main ( ) . runtime ( ) . is_acir ( ) ,
255
- is_brillig_entry_point : ssa. main ( ) . runtime ( ) . is_brillig ( ) ,
256
- ..Default :: default ( )
257
- } ,
258
- ) ;
259
-
260
- // Handle ACIR functions.
261
- for ( func_id, function) in ssa. functions . iter ( ) {
262
- if function. runtime ( ) . is_brillig ( ) {
263
- continue ;
264
- }
265
-
266
- // If we have not already finished the flattening pass, functions marked
267
- // to not have predicates should be preserved.
268
- let preserve_function = !inline_no_predicates_functions && function. is_no_predicates ( ) ;
269
- if function. runtime ( ) . is_entry_point ( ) || preserve_function {
270
- inline_infos. entry ( * func_id) . or_default ( ) . is_acir_entry_point = true ;
271
- }
272
-
273
- // Any Brillig function called from ACIR is an entry into the Brillig VM.
274
- for called_func_id in called_functions ( function) {
275
- if ssa. functions [ & called_func_id] . runtime ( ) . is_brillig ( ) {
276
- inline_infos. entry ( called_func_id) . or_default ( ) . is_brillig_entry_point = true ;
277
- }
278
- }
279
- }
280
-
281
- let callers = compute_callers ( ssa) ;
282
- let times_called = compute_times_called ( & callers) ;
283
-
284
- mark_brillig_functions_to_retain (
285
- ssa,
286
- inline_no_predicates_functions,
287
- aggressiveness,
288
- & times_called,
289
- & mut inline_infos,
290
- ) ;
291
-
292
- inline_infos
293
- }
294
-
295
- /// Compute the time each function is called from any other function.
296
- fn compute_times_called (
297
- callers : & BTreeMap < FunctionId , BTreeMap < FunctionId , usize > > ,
298
- ) -> HashMap < FunctionId , usize > {
299
- callers
300
- . iter ( )
301
- . map ( |( callee, callers) | {
302
- let total_calls = callers. values ( ) . sum ( ) ;
303
- ( * callee, total_calls)
304
- } )
305
- . collect ( )
306
- }
307
-
308
- /// Compute for each function the set of functions that call it, and how many times they do so.
309
- fn compute_callers ( ssa : & Ssa ) -> BTreeMap < FunctionId , BTreeMap < FunctionId , usize > > {
310
- ssa. functions
311
- . iter ( )
312
- . flat_map ( |( caller_id, function) | {
313
- let called_functions = called_functions_vec ( function) ;
314
- called_functions. into_iter ( ) . map ( |callee_id| ( * caller_id, callee_id) )
315
- } )
316
- . fold (
317
- // Make sure an entry exists even for ones that don't get called.
318
- ssa. functions . keys ( ) . map ( |id| ( * id, BTreeMap :: new ( ) ) ) . collect ( ) ,
319
- |mut acc, ( caller_id, callee_id) | {
320
- let callers = acc. entry ( callee_id) . or_default ( ) ;
321
- * callers. entry ( caller_id) . or_default ( ) += 1 ;
322
- acc
323
- } ,
324
- )
325
- }
326
-
327
- /// Compute for each function the set of functions called by it, and how many times it does so.
328
- fn compute_callees ( ssa : & Ssa ) -> BTreeMap < FunctionId , BTreeMap < FunctionId , usize > > {
329
- ssa. functions
330
- . iter ( )
331
- . flat_map ( |( caller_id, function) | {
332
- let called_functions = called_functions_vec ( function) ;
333
- called_functions. into_iter ( ) . map ( |callee_id| ( * caller_id, callee_id) )
334
- } )
335
- . fold (
336
- // Make sure an entry exists even for ones that don't call anything.
337
- ssa. functions . keys ( ) . map ( |id| ( * id, BTreeMap :: new ( ) ) ) . collect ( ) ,
338
- |mut acc, ( caller_id, callee_id) | {
339
- let callees = acc. entry ( caller_id) . or_default ( ) ;
340
- * callees. entry ( callee_id) . or_default ( ) += 1 ;
341
- acc
342
- } ,
343
- )
344
- }
345
213
346
- /// Compute something like a topological order of the functions, starting with the ones
347
- /// that do not call any other functions, going towards the entry points. When cycles
348
- /// are detected, take the one which are called by the most to break the ties.
349
- ///
350
- /// This can be used to simplify the most often called functions first.
351
- ///
352
- /// Returns the functions paired with their own as well as transitive weight,
353
- /// which accumulates the weight of all the functions they call, as well as own.
354
- pub ( super ) fn compute_bottom_up_order ( ssa : & Ssa ) -> Vec < ( FunctionId , ( usize , usize ) ) > {
355
- let mut order = Vec :: new ( ) ;
356
- let mut visited = HashSet :: new ( ) ;
357
-
358
- // Call graph which we'll repeatedly prune to find the "leaves".
359
- let mut callees = compute_callees ( ssa) ;
360
- let callers = compute_callers ( ssa) ;
361
-
362
- // Number of times a function is called, used to break cycles in the call graph by popping the next candidate.
363
- let mut times_called = compute_times_called ( & callers) . into_iter ( ) . collect :: < Vec < _ > > ( ) ;
364
- times_called. sort_by_key ( |( id, cnt) | {
365
- // Sort by called the *least* by others, as these are less likely to cut the graph when removed.
366
- let called_desc = -( * cnt as i64 ) ;
367
- // Sort entries first (last to be popped).
368
- let is_entry_asc = -called_desc. signum ( ) ;
369
- // Finally break ties by ID.
370
- ( is_entry_asc, called_desc, * id)
371
- } ) ;
372
-
373
- // Start with the weight of the functions in isolation, then accumulate as we pop off the ones they call.
374
- let own_weights = ssa
375
- . functions
376
- . iter ( )
377
- . map ( |( id, f) | ( * id, compute_function_own_weight ( f) ) )
378
- . collect :: < HashMap < _ , _ > > ( ) ;
379
- let mut weights = own_weights. clone ( ) ;
380
-
381
- // Seed the queue with functions that don't call anything.
382
- let mut queue = callees
383
- . iter ( )
384
- . filter_map ( |( id, callees) | callees. is_empty ( ) . then_some ( * id) )
385
- . collect :: < VecDeque < _ > > ( ) ;
386
-
387
- loop {
388
- while let Some ( id) = queue. pop_front ( ) {
389
- // Pull the current weight of yet-to-be emitted callees (a nod to mutual recursion).
390
- for ( callee, cnt) in & callees[ & id] {
391
- if * callee != id {
392
- weights[ & id] = weights[ & id] . saturating_add ( cnt. saturating_mul ( weights[ callee] ) ) ;
393
- }
394
- }
395
- // Own weight plus the weights accumulated from callees.
396
- let weight = weights[ & id] ;
397
- let own_weight = own_weights[ & id] ;
398
-
399
- // Emit the function.
400
- order. push ( ( id, ( own_weight, weight) ) ) ;
401
- visited. insert ( id) ;
402
-
403
- // Update the callers of this function.
404
- for ( caller, cnt) in & callers[ & id] {
405
- // Update the weight of the caller with the weight of this function.
406
- weights[ caller] = weights[ caller] . saturating_add ( cnt. saturating_mul ( weight) ) ;
407
- // Remove this function from the callees of the caller.
408
- let callees = callees. get_mut ( caller) . unwrap ( ) ;
409
- callees. remove ( & id) ;
410
- // If the caller doesn't call any other function, enqueue it,
411
- // unless it's the entry function, which is never called by anything, so it should be last.
412
- if callees. is_empty ( ) && !visited. contains ( caller) && !callers[ caller] . is_empty ( ) {
413
- queue. push_back ( * caller) ;
414
- }
415
- }
416
- }
417
- // If we ran out of the queue, maybe there is a cycle; take the next most called function.
418
- while let Some ( ( id, _) ) = times_called. pop ( ) {
419
- if !visited. contains ( & id) {
420
- queue. push_back ( id) ;
421
- break ;
422
- }
423
- }
424
- if times_called. is_empty ( ) && queue. is_empty ( ) {
425
- assert_eq ! ( order. len( ) , callers. len( ) ) ;
426
- return order;
427
- }
428
- }
429
- }
430
-
431
- /// Traverse the call graph starting from a given function, marking function to be retained if they are:
432
- /// * recursive functions, or
433
- /// * the cost of inlining outweighs the cost of not doing so
434
- fn mark_functions_to_retain_recursive (
435
- ssa : & Ssa ,
436
- inline_no_predicates_functions : bool ,
437
- aggressiveness : i64 ,
438
- times_called : & HashMap < FunctionId , usize > ,
439
- inline_infos : & mut InlineInfos ,
440
- mut explored_functions : im:: HashSet < FunctionId > ,
441
- func : FunctionId ,
442
- ) {
443
- // Check if we have set any of the fields this method touches.
444
- let decided = |inline_infos : & InlineInfos | {
445
- inline_infos
446
- . get ( & func)
447
- . map ( |info| info. is_recursive || info. should_inline || info. weight != 0 )
448
- . unwrap_or_default ( )
449
- } ;
450
-
451
- // Check if we have already decided on this function
452
- if decided ( inline_infos) {
453
- return ;
454
- }
455
-
456
- // If recursive, this function won't be inlined
457
- if explored_functions. contains ( & func) {
458
- inline_infos. entry ( func) . or_default ( ) . is_recursive = true ;
459
- return ;
460
- }
461
- explored_functions. insert ( func) ;
462
-
463
- // Decide on dependencies first, so we know their weight.
464
- let called_functions = called_functions_vec ( & ssa. functions [ & func] ) ;
465
- for callee in & called_functions {
466
- mark_functions_to_retain_recursive (
467
- ssa,
468
- inline_no_predicates_functions,
469
- aggressiveness,
470
- times_called,
471
- inline_infos,
472
- explored_functions. clone ( ) ,
473
- * callee,
474
- ) ;
475
- }
476
-
477
- // We could have decided on this function while deciding on dependencies
478
- // if the function is recursive.
479
- if decided ( inline_infos) {
480
- return ;
481
- }
482
-
483
- // We'll use some heuristics to decide whether to inline or not.
484
- // We compute the weight (roughly the number of instructions) of the function after inlining
485
- // And the interface cost of the function (the inherent cost at the callsite, roughly the number of args and returns)
486
- // We then can compute an approximation of the cost of inlining vs the cost of retaining the function
487
- // We do this computation using saturating i64s to avoid overflows,
488
- // and because we want to calculate a difference which can be negative.
489
-
490
- // Total weight of functions called by this one, unless we decided not to inline them.
491
- // Callees which appear multiple times would be inlined multiple times.
492
- let inlined_function_weights: i64 = called_functions. iter ( ) . fold ( 0 , |acc, callee| {
493
- let info = & inline_infos[ callee] ;
494
- // If the callee is not going to be inlined then we can ignore its cost.
495
- if info. should_inline {
496
- acc. saturating_add ( info. weight )
497
- } else {
498
- acc
499
- }
500
- } ) ;
501
-
502
- let this_function_weight = inlined_function_weights
503
- . saturating_add ( compute_function_own_weight ( & ssa. functions [ & func] ) as i64 ) ;
504
-
505
- let interface_cost = compute_function_interface_cost ( & ssa. functions [ & func] ) as i64 ;
506
-
507
- let times_called = times_called[ & func] as i64 ;
508
-
509
- let inline_cost = times_called. saturating_mul ( this_function_weight) ;
510
- let retain_cost = times_called. saturating_mul ( interface_cost) + this_function_weight;
511
- let net_cost = inline_cost. saturating_sub ( retain_cost) ;
512
-
513
- let runtime = ssa. functions [ & func] . runtime ( ) ;
514
- // We inline if the aggressiveness is higher than inline cost minus the retain cost
515
- // If aggressiveness is infinite, we'll always inline
516
- // If aggressiveness is 0, we'll inline when the inline cost is lower than the retain cost
517
- // If aggressiveness is minus infinity, we'll never inline (other than in the mandatory cases)
518
- let should_inline = ( net_cost < aggressiveness)
519
- || runtime. is_inline_always ( )
520
- || ( runtime. is_no_predicates ( ) && inline_no_predicates_functions) ;
521
-
522
- let info = inline_infos. entry ( func) . or_default ( ) ;
523
- info. should_inline = should_inline;
524
- info. weight = this_function_weight;
525
- info. cost = net_cost;
526
- }
527
-
528
- /// Mark Brillig functions that should not be inlined because they are recursive or expensive.
529
- fn mark_brillig_functions_to_retain (
530
- ssa : & Ssa ,
531
- inline_no_predicates_functions : bool ,
532
- aggressiveness : i64 ,
533
- times_called : & HashMap < FunctionId , usize > ,
534
- inline_infos : & mut InlineInfos ,
535
- ) {
536
- let brillig_entry_points = inline_infos
537
- . iter ( )
538
- . filter_map ( |( id, info) | info. is_brillig_entry_point . then_some ( * id) )
539
- . collect :: < Vec < _ > > ( ) ;
540
-
541
- for entry_point in brillig_entry_points {
542
- mark_functions_to_retain_recursive (
543
- ssa,
544
- inline_no_predicates_functions,
545
- aggressiveness,
546
- times_called,
547
- inline_infos,
548
- im:: HashSet :: default ( ) ,
549
- entry_point,
550
- ) ;
551
- }
552
- }
553
-
554
- /// Compute a weight of a function based on the number of instructions in its reachable blocks.
555
- fn compute_function_own_weight ( func : & Function ) -> usize {
556
- let mut weight = 0 ;
557
- for block_id in func. reachable_blocks ( ) {
558
- weight += func. dfg [ block_id] . instructions ( ) . len ( ) + 1 ; // We add one for the terminator
559
- }
560
- // We use an approximation of the average increase in instruction ratio from SSA to Brillig
561
- // In order to get the actual weight we'd need to codegen this function to brillig.
562
- weight
563
- }
564
-
565
- /// Compute interface cost of a function based on the number of inputs and outputs.
566
- fn compute_function_interface_cost ( func : & Function ) -> usize {
567
- func. parameters ( ) . len ( ) + func. returns ( ) . len ( )
568
- }
569
214
570
215
impl InlineContext {
571
216
/// Create a new context object for the function inlining pass.
0 commit comments