Skip to content

Commit 11bc7e4

Browse files
committed
chore: push inlining info code into a submodule
1 parent 0b6c382 commit 11bc7e4

File tree

3 files changed

+378
-362
lines changed

3 files changed

+378
-362
lines changed

compiler/noirc_evaluator/src/ssa/opt/inlining.rs

+5-360
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//! The purpose of this pass is to inline the instructions of each function call
33
//! within the function caller. If all function calls are known, there will only
44
//! be a single function remaining when the pass finishes.
5-
use std::collections::{BTreeMap, BTreeSet, HashSet, VecDeque};
5+
use std::collections::{BTreeSet, HashSet, VecDeque};
66

77
use acvm::acir::AcirField;
88
use im::HashMap;
@@ -21,6 +21,10 @@ use crate::ssa::{
2121
ssa_gen::Ssa,
2222
};
2323

24+
pub(super) mod inline_info;
25+
26+
pub(super) use inline_info::{compute_inline_infos, InlineInfo, InlineInfos};
27+
2428
/// An arbitrary limit to the maximum number of recursive call
2529
/// frames at any point in time.
2630
const RECURSION_LIMIT: u32 = 1000;
@@ -206,366 +210,7 @@ fn called_functions(func: &Function) -> BTreeSet<FunctionId> {
206210
called_functions_vec(func).into_iter().collect()
207211
}
208212

209-
/// Information about a function to aid the decision about whether to inline it or not.
210-
/// The final decision depends on what we're inlining it into.
211-
#[derive(Default, Debug)]
212-
pub(super) struct InlineInfo {
213-
is_brillig_entry_point: bool,
214-
is_acir_entry_point: bool,
215-
is_recursive: bool,
216-
pub(super) should_inline: bool,
217-
weight: i64,
218-
cost: i64,
219-
}
220-
221-
impl InlineInfo {
222-
/// Functions which are to be retained, not inlined.
223-
pub(super) fn is_inline_target(&self) -> bool {
224-
self.is_brillig_entry_point
225-
|| self.is_acir_entry_point
226-
|| self.is_recursive
227-
|| !self.should_inline
228-
}
229-
230-
pub(super) fn should_inline(inline_infos: &InlineInfos, called_func_id: FunctionId) -> bool {
231-
inline_infos.get(&called_func_id).map(|info| info.should_inline).unwrap_or_default()
232-
}
233-
}
234-
235-
type InlineInfos = BTreeMap<FunctionId, InlineInfo>;
236-
237-
/// The functions we should inline into (and that should be left in the final program) are:
238-
/// - main
239-
/// - Any Brillig function called from Acir
240-
/// - Some Brillig functions depending on aggressiveness and some metrics
241-
/// - Any Acir functions with a [fold inline type][InlineType::Fold].
242-
///
243-
/// The returned `InlineInfos` won't have every function in it, only the ones which the algorithm visited.
244-
pub(super) fn compute_inline_infos(
245-
ssa: &Ssa,
246-
inline_no_predicates_functions: bool,
247-
aggressiveness: i64,
248-
) -> InlineInfos {
249-
let mut inline_infos = InlineInfos::default();
250-
251-
inline_infos.insert(
252-
ssa.main_id,
253-
InlineInfo {
254-
is_acir_entry_point: ssa.main().runtime().is_acir(),
255-
is_brillig_entry_point: ssa.main().runtime().is_brillig(),
256-
..Default::default()
257-
},
258-
);
259-
260-
// Handle ACIR functions.
261-
for (func_id, function) in ssa.functions.iter() {
262-
if function.runtime().is_brillig() {
263-
continue;
264-
}
265-
266-
// If we have not already finished the flattening pass, functions marked
267-
// to not have predicates should be preserved.
268-
let preserve_function = !inline_no_predicates_functions && function.is_no_predicates();
269-
if function.runtime().is_entry_point() || preserve_function {
270-
inline_infos.entry(*func_id).or_default().is_acir_entry_point = true;
271-
}
272-
273-
// Any Brillig function called from ACIR is an entry into the Brillig VM.
274-
for called_func_id in called_functions(function) {
275-
if ssa.functions[&called_func_id].runtime().is_brillig() {
276-
inline_infos.entry(called_func_id).or_default().is_brillig_entry_point = true;
277-
}
278-
}
279-
}
280-
281-
let callers = compute_callers(ssa);
282-
let times_called = compute_times_called(&callers);
283-
284-
mark_brillig_functions_to_retain(
285-
ssa,
286-
inline_no_predicates_functions,
287-
aggressiveness,
288-
&times_called,
289-
&mut inline_infos,
290-
);
291-
292-
inline_infos
293-
}
294-
295-
/// Compute the number of times each function is called from any other function.
296-
fn compute_times_called(
297-
callers: &BTreeMap<FunctionId, BTreeMap<FunctionId, usize>>,
298-
) -> HashMap<FunctionId, usize> {
299-
callers
300-
.iter()
301-
.map(|(callee, callers)| {
302-
let total_calls = callers.values().sum();
303-
(*callee, total_calls)
304-
})
305-
.collect()
306-
}
307-
308-
/// Compute for each function the set of functions that call it, and how many times they do so.
309-
fn compute_callers(ssa: &Ssa) -> BTreeMap<FunctionId, BTreeMap<FunctionId, usize>> {
310-
ssa.functions
311-
.iter()
312-
.flat_map(|(caller_id, function)| {
313-
let called_functions = called_functions_vec(function);
314-
called_functions.into_iter().map(|callee_id| (*caller_id, callee_id))
315-
})
316-
.fold(
317-
// Make sure an entry exists even for ones that don't get called.
318-
ssa.functions.keys().map(|id| (*id, BTreeMap::new())).collect(),
319-
|mut acc, (caller_id, callee_id)| {
320-
let callers = acc.entry(callee_id).or_default();
321-
*callers.entry(caller_id).or_default() += 1;
322-
acc
323-
},
324-
)
325-
}
326-
327-
/// Compute for each function the set of functions called by it, and how many times it does so.
328-
fn compute_callees(ssa: &Ssa) -> BTreeMap<FunctionId, BTreeMap<FunctionId, usize>> {
329-
ssa.functions
330-
.iter()
331-
.flat_map(|(caller_id, function)| {
332-
let called_functions = called_functions_vec(function);
333-
called_functions.into_iter().map(|callee_id| (*caller_id, callee_id))
334-
})
335-
.fold(
336-
// Make sure an entry exists even for ones that don't call anything.
337-
ssa.functions.keys().map(|id| (*id, BTreeMap::new())).collect(),
338-
|mut acc, (caller_id, callee_id)| {
339-
let callees = acc.entry(caller_id).or_default();
340-
*callees.entry(callee_id).or_default() += 1;
341-
acc
342-
},
343-
)
344-
}
345213

346-
/// Compute something like a topological order of the functions, starting with the ones
347-
/// that do not call any other functions, going towards the entry points. When cycles
348-
/// are detected, take the one which is called by the most to break the ties.
349-
///
350-
/// This can be used to simplify the most often called functions first.
351-
///
352-
/// Returns the functions paired with their own as well as transitive weight,
353-
/// which accumulates the weight of all the functions they call, as well as their own.
354-
pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec<(FunctionId, (usize, usize))> {
355-
let mut order = Vec::new();
356-
let mut visited = HashSet::new();
357-
358-
// Call graph which we'll repeatedly prune to find the "leaves".
359-
let mut callees = compute_callees(ssa);
360-
let callers = compute_callers(ssa);
361-
362-
// Number of times a function is called, used to break cycles in the call graph by popping the next candidate.
363-
let mut times_called = compute_times_called(&callers).into_iter().collect::<Vec<_>>();
364-
times_called.sort_by_key(|(id, cnt)| {
365-
// Sort by called the *least* by others, as these are less likely to cut the graph when removed.
366-
let called_desc = -(*cnt as i64);
367-
// Sort entries first (last to be popped).
368-
let is_entry_asc = -called_desc.signum();
369-
// Finally break ties by ID.
370-
(is_entry_asc, called_desc, *id)
371-
});
372-
373-
// Start with the weight of the functions in isolation, then accumulate as we pop off the ones they call.
374-
let own_weights = ssa
375-
.functions
376-
.iter()
377-
.map(|(id, f)| (*id, compute_function_own_weight(f)))
378-
.collect::<HashMap<_, _>>();
379-
let mut weights = own_weights.clone();
380-
381-
// Seed the queue with functions that don't call anything.
382-
let mut queue = callees
383-
.iter()
384-
.filter_map(|(id, callees)| callees.is_empty().then_some(*id))
385-
.collect::<VecDeque<_>>();
386-
387-
loop {
388-
while let Some(id) = queue.pop_front() {
389-
// Pull the current weight of yet-to-be emitted callees (a nod to mutual recursion).
390-
for (callee, cnt) in &callees[&id] {
391-
if *callee != id {
392-
weights[&id] = weights[&id].saturating_add(cnt.saturating_mul(weights[callee]));
393-
}
394-
}
395-
// Own weight plus the weights accumulated from callees.
396-
let weight = weights[&id];
397-
let own_weight = own_weights[&id];
398-
399-
// Emit the function.
400-
order.push((id, (own_weight, weight)));
401-
visited.insert(id);
402-
403-
// Update the callers of this function.
404-
for (caller, cnt) in &callers[&id] {
405-
// Update the weight of the caller with the weight of this function.
406-
weights[caller] = weights[caller].saturating_add(cnt.saturating_mul(weight));
407-
// Remove this function from the callees of the caller.
408-
let callees = callees.get_mut(caller).unwrap();
409-
callees.remove(&id);
410-
// If the caller doesn't call any other function, enqueue it,
411-
// unless it's the entry function, which is never called by anything, so it should be last.
412-
if callees.is_empty() && !visited.contains(caller) && !callers[caller].is_empty() {
413-
queue.push_back(*caller);
414-
}
415-
}
416-
}
417-
// If we ran out of the queue, maybe there is a cycle; take the next most called function.
418-
while let Some((id, _)) = times_called.pop() {
419-
if !visited.contains(&id) {
420-
queue.push_back(id);
421-
break;
422-
}
423-
}
424-
if times_called.is_empty() && queue.is_empty() {
425-
assert_eq!(order.len(), callers.len());
426-
return order;
427-
}
428-
}
429-
}
430-
431-
/// Traverse the call graph starting from a given function, marking functions to be retained if they are:
432-
/// * recursive functions, or
433-
/// * functions for which the cost of inlining outweighs the cost of not doing so
434-
fn mark_functions_to_retain_recursive(
435-
ssa: &Ssa,
436-
inline_no_predicates_functions: bool,
437-
aggressiveness: i64,
438-
times_called: &HashMap<FunctionId, usize>,
439-
inline_infos: &mut InlineInfos,
440-
mut explored_functions: im::HashSet<FunctionId>,
441-
func: FunctionId,
442-
) {
443-
// Check if we have set any of the fields this method touches.
444-
let decided = |inline_infos: &InlineInfos| {
445-
inline_infos
446-
.get(&func)
447-
.map(|info| info.is_recursive || info.should_inline || info.weight != 0)
448-
.unwrap_or_default()
449-
};
450-
451-
// Check if we have already decided on this function
452-
if decided(inline_infos) {
453-
return;
454-
}
455-
456-
// If recursive, this function won't be inlined
457-
if explored_functions.contains(&func) {
458-
inline_infos.entry(func).or_default().is_recursive = true;
459-
return;
460-
}
461-
explored_functions.insert(func);
462-
463-
// Decide on dependencies first, so we know their weight.
464-
let called_functions = called_functions_vec(&ssa.functions[&func]);
465-
for callee in &called_functions {
466-
mark_functions_to_retain_recursive(
467-
ssa,
468-
inline_no_predicates_functions,
469-
aggressiveness,
470-
times_called,
471-
inline_infos,
472-
explored_functions.clone(),
473-
*callee,
474-
);
475-
}
476-
477-
// We could have decided on this function while deciding on dependencies
478-
// if the function is recursive.
479-
if decided(inline_infos) {
480-
return;
481-
}
482-
483-
// We'll use some heuristics to decide whether to inline or not.
484-
// We compute the weight (roughly the number of instructions) of the function after inlining
485-
// And the interface cost of the function (the inherent cost at the callsite, roughly the number of args and returns)
486-
// We then can compute an approximation of the cost of inlining vs the cost of retaining the function
487-
// We do this computation using saturating i64s to avoid overflows,
488-
// and because we want to calculate a difference which can be negative.
489-
490-
// Total weight of functions called by this one, unless we decided not to inline them.
491-
// Callees which appear multiple times would be inlined multiple times.
492-
let inlined_function_weights: i64 = called_functions.iter().fold(0, |acc, callee| {
493-
let info = &inline_infos[callee];
494-
// If the callee is not going to be inlined then we can ignore its cost.
495-
if info.should_inline {
496-
acc.saturating_add(info.weight)
497-
} else {
498-
acc
499-
}
500-
});
501-
502-
let this_function_weight = inlined_function_weights
503-
.saturating_add(compute_function_own_weight(&ssa.functions[&func]) as i64);
504-
505-
let interface_cost = compute_function_interface_cost(&ssa.functions[&func]) as i64;
506-
507-
let times_called = times_called[&func] as i64;
508-
509-
let inline_cost = times_called.saturating_mul(this_function_weight);
510-
let retain_cost = times_called.saturating_mul(interface_cost) + this_function_weight;
511-
let net_cost = inline_cost.saturating_sub(retain_cost);
512-
513-
let runtime = ssa.functions[&func].runtime();
514-
// We inline if the aggressiveness is higher than inline cost minus the retain cost
515-
// If aggressiveness is infinite, we'll always inline
516-
// If aggressiveness is 0, we'll inline when the inline cost is lower than the retain cost
517-
// If aggressiveness is minus infinity, we'll never inline (other than in the mandatory cases)
518-
let should_inline = (net_cost < aggressiveness)
519-
|| runtime.is_inline_always()
520-
|| (runtime.is_no_predicates() && inline_no_predicates_functions);
521-
522-
let info = inline_infos.entry(func).or_default();
523-
info.should_inline = should_inline;
524-
info.weight = this_function_weight;
525-
info.cost = net_cost;
526-
}
527-
528-
/// Mark Brillig functions that should not be inlined because they are recursive or expensive.
529-
fn mark_brillig_functions_to_retain(
530-
ssa: &Ssa,
531-
inline_no_predicates_functions: bool,
532-
aggressiveness: i64,
533-
times_called: &HashMap<FunctionId, usize>,
534-
inline_infos: &mut InlineInfos,
535-
) {
536-
let brillig_entry_points = inline_infos
537-
.iter()
538-
.filter_map(|(id, info)| info.is_brillig_entry_point.then_some(*id))
539-
.collect::<Vec<_>>();
540-
541-
for entry_point in brillig_entry_points {
542-
mark_functions_to_retain_recursive(
543-
ssa,
544-
inline_no_predicates_functions,
545-
aggressiveness,
546-
times_called,
547-
inline_infos,
548-
im::HashSet::default(),
549-
entry_point,
550-
);
551-
}
552-
}
553-
554-
/// Compute a weight of a function based on the number of instructions in its reachable blocks.
555-
fn compute_function_own_weight(func: &Function) -> usize {
556-
let mut weight = 0;
557-
for block_id in func.reachable_blocks() {
558-
weight += func.dfg[block_id].instructions().len() + 1; // We add one for the terminator
559-
}
560-
// We use an approximation of the average increase in instruction ratio from SSA to Brillig
561-
// In order to get the actual weight we'd need to codegen this function to brillig.
562-
weight
563-
}
564-
565-
/// Compute interface cost of a function based on the number of inputs and outputs.
566-
fn compute_function_interface_cost(func: &Function) -> usize {
567-
func.parameters().len() + func.returns().len()
568-
}
569214

570215
impl InlineContext {
571216
/// Create a new context object for the function inlining pass.

0 commit comments

Comments
 (0)