Skip to content

Commit c241e14

Browse files
committed
Auto merge of rust-lang#136593 - lukas-code:ty-value-perf, r=oli-obk
valtree performance tuning Summary: This PR makes type checking of code with many type-level constants faster. After rust-lang#136180 was merged, we observed a small perf regression (rust-lang#136318 (comment)). This happened because that PR introduced additional copies in the fast reject code path for consts, which is very hot for certain crates: https://github.com/rust-lang/rust/blob/6c1d960d88dd3755548b3818630acb63fa98187e/compiler/rustc_type_ir/src/fast_reject.rs#L486-L487 This PR improves the performance again by properly interning the valtrees so that copying and comparing them become faster. This will become especially useful with `feature(adt_const_params)`, so the fast reject code doesn't have to do a deep compare of the valtrees. Note that we can't just compare the interned consts themselves in the fast reject, because sometimes `'static` lifetimes in the type are replaced with inference variables (due to canonicalization) on one side but not the other. A less invasive alternative that I considered is simply avoiding the copies introduced by rust-lang#136180 and comparing the valtrees in-place (see commit: rust-lang@9e91e50 / perf results: rust-lang#136593 (comment)); however, that was still measurably slower than interning. There are some minor regressions in secondary benchmarks. These happen due to changes in memory allocations and seem acceptable to me. The crates that make heavy use of valtrees show no significant changes in memory usage.
2 parents 54cdc75 + b722d5d commit c241e14

File tree

15 files changed

+158
-129
lines changed

15 files changed

+158
-129
lines changed

compiler/rustc_const_eval/src/const_eval/valtrees.rs

+21-27
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use rustc_abi::{BackendRepr, VariantIdx};
22
use rustc_data_structures::stack::ensure_sufficient_stack;
33
use rustc_middle::mir::interpret::{EvalToValTreeResult, GlobalId, ReportedErrorInfo};
44
use rustc_middle::ty::layout::{LayoutCx, LayoutOf, TyAndLayout};
5-
use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
5+
use rustc_middle::ty::{self, Ty, TyCtxt};
66
use rustc_middle::{bug, mir};
77
use rustc_span::DUMMY_SP;
88
use tracing::{debug, instrument, trace};
@@ -21,38 +21,36 @@ use crate::interpret::{
2121
fn branches<'tcx>(
2222
ecx: &CompileTimeInterpCx<'tcx>,
2323
place: &MPlaceTy<'tcx>,
24-
n: usize,
24+
field_count: usize,
2525
variant: Option<VariantIdx>,
2626
num_nodes: &mut usize,
2727
) -> ValTreeCreationResult<'tcx> {
2828
let place = match variant {
2929
Some(variant) => ecx.project_downcast(place, variant).unwrap(),
3030
None => place.clone(),
3131
};
32-
let variant = variant.map(|variant| Some(ty::ValTree::Leaf(ScalarInt::from(variant.as_u32()))));
33-
debug!(?place, ?variant);
32+
debug!(?place);
3433

35-
let mut fields = Vec::with_capacity(n);
36-
for i in 0..n {
37-
let field = ecx.project_field(&place, i).unwrap();
38-
let valtree = const_to_valtree_inner(ecx, &field, num_nodes)?;
39-
fields.push(Some(valtree));
40-
}
34+
let mut branches = Vec::with_capacity(field_count + variant.is_some() as usize);
4135

4236
// For enums, we prepend their variant index before the variant's fields so we can figure out
4337
// the variant again when just seeing a valtree.
44-
let branches = variant
45-
.into_iter()
46-
.chain(fields.into_iter())
47-
.collect::<Option<Vec<_>>>()
48-
.expect("should have already checked for errors in ValTree creation");
38+
if let Some(variant) = variant {
39+
branches.push(ty::ValTree::from_scalar_int(*ecx.tcx, variant.as_u32().into()));
40+
}
41+
42+
for i in 0..field_count {
43+
let field = ecx.project_field(&place, i).unwrap();
44+
let valtree = const_to_valtree_inner(ecx, &field, num_nodes)?;
45+
branches.push(valtree);
46+
}
4947

5048
// Have to account for ZSTs here
5149
if branches.len() == 0 {
5250
*num_nodes += 1;
5351
}
5452

55-
Ok(ty::ValTree::Branch(ecx.tcx.arena.alloc_from_iter(branches)))
53+
Ok(ty::ValTree::from_branches(*ecx.tcx, branches))
5654
}
5755

5856
#[instrument(skip(ecx), level = "debug")]
@@ -70,7 +68,7 @@ fn slice_branches<'tcx>(
7068
elems.push(valtree);
7169
}
7270

73-
Ok(ty::ValTree::Branch(ecx.tcx.arena.alloc_from_iter(elems)))
71+
Ok(ty::ValTree::from_branches(*ecx.tcx, elems))
7472
}
7573

7674
#[instrument(skip(ecx), level = "debug")]
@@ -79,6 +77,7 @@ fn const_to_valtree_inner<'tcx>(
7977
place: &MPlaceTy<'tcx>,
8078
num_nodes: &mut usize,
8179
) -> ValTreeCreationResult<'tcx> {
80+
let tcx = *ecx.tcx;
8281
let ty = place.layout.ty;
8382
debug!("ty kind: {:?}", ty.kind());
8483

@@ -89,14 +88,14 @@ fn const_to_valtree_inner<'tcx>(
8988
match ty.kind() {
9089
ty::FnDef(..) => {
9190
*num_nodes += 1;
92-
Ok(ty::ValTree::zst())
91+
Ok(ty::ValTree::zst(tcx))
9392
}
9493
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => {
9594
let val = ecx.read_immediate(place).unwrap();
9695
let val = val.to_scalar_int().unwrap();
9796
*num_nodes += 1;
9897

99-
Ok(ty::ValTree::Leaf(val))
98+
Ok(ty::ValTree::from_scalar_int(tcx, val))
10099
}
101100

102101
ty::Pat(base, ..) => {
@@ -127,7 +126,7 @@ fn const_to_valtree_inner<'tcx>(
127126
return Err(ValTreeCreationError::NonSupportedType(ty));
128127
};
129128
// It's just a ScalarInt!
130-
Ok(ty::ValTree::Leaf(val))
129+
Ok(ty::ValTree::from_scalar_int(tcx, val))
131130
}
132131

133132
// Technically we could allow function pointers (represented as `ty::Instance`), but this is not guaranteed to
@@ -287,16 +286,11 @@ pub fn valtree_to_const_value<'tcx>(
287286
// FIXME: Does this need an example?
288287
match *cv.ty.kind() {
289288
ty::FnDef(..) => {
290-
assert!(cv.valtree.unwrap_branch().is_empty());
289+
assert!(cv.valtree.is_zst());
291290
mir::ConstValue::ZeroSized
292291
}
293292
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char | ty::RawPtr(_, _) => {
294-
match cv.valtree {
295-
ty::ValTree::Leaf(scalar_int) => mir::ConstValue::Scalar(Scalar::Int(scalar_int)),
296-
ty::ValTree::Branch(_) => bug!(
297-
"ValTrees for Bool, Int, Uint, Float, Char or RawPtr should have the form ValTree::Leaf"
298-
),
299-
}
293+
mir::ConstValue::Scalar(Scalar::Int(cv.valtree.unwrap_leaf()))
300294
}
301295
ty::Pat(ty, _) => {
302296
let cv = ty::Value { valtree: cv.valtree, ty };

compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -2161,7 +2161,7 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
21612161
did,
21622162
path.segments.last().unwrap(),
21632163
);
2164-
ty::Const::new_value(tcx, ty::ValTree::zst(), Ty::new_fn_def(tcx, did, args))
2164+
ty::Const::zero_sized(tcx, Ty::new_fn_def(tcx, did, args))
21652165
}
21662166

21672167
// Exhaustive match to be clear about what exactly we're considering to be

compiler/rustc_middle/src/arena.rs

+1
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ macro_rules! arena_types {
9090
[] autodiff_item: rustc_ast::expand::autodiff_attrs::AutoDiffItem,
9191
[] ordered_name_set: rustc_data_structures::fx::FxIndexSet<rustc_span::Symbol>,
9292
[] pats: rustc_middle::ty::PatternKind<'tcx>,
93+
[] valtree: rustc_middle::ty::ValTreeKind<'tcx>,
9394

9495
// Note that this deliberately duplicates items in the `rustc_hir::arena`,
9596
// since we need to allocate this type on both the `rustc_hir` arena

compiler/rustc_middle/src/ty/codec.rs

+9-6
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,12 @@ impl<'tcx, E: TyEncoder<I = TyCtxt<'tcx>>> Encodable<E> for ty::Pattern<'tcx> {
146146
}
147147
}
148148

149+
impl<'tcx, E: TyEncoder<I = TyCtxt<'tcx>>> Encodable<E> for ty::ValTree<'tcx> {
150+
fn encode(&self, e: &mut E) {
151+
self.0.0.encode(e);
152+
}
153+
}
154+
149155
impl<'tcx, E: TyEncoder<I = TyCtxt<'tcx>>> Encodable<E> for ConstAllocation<'tcx> {
150156
fn encode(&self, e: &mut E) {
151157
self.inner().encode(e)
@@ -355,12 +361,9 @@ impl<'tcx, D: TyDecoder<I = TyCtxt<'tcx>>> Decodable<D> for ty::Pattern<'tcx> {
355361
}
356362
}
357363

358-
impl<'tcx, D: TyDecoder<I = TyCtxt<'tcx>>> RefDecodable<'tcx, D> for [ty::ValTree<'tcx>] {
359-
fn decode(decoder: &mut D) -> &'tcx Self {
360-
decoder
361-
.interner()
362-
.arena
363-
.alloc_from_iter((0..decoder.read_usize()).map(|_| Decodable::decode(decoder)))
364+
impl<'tcx, D: TyDecoder<I = TyCtxt<'tcx>>> Decodable<D> for ty::ValTree<'tcx> {
365+
fn decode(decoder: &mut D) -> Self {
366+
decoder.interner().intern_valtree(Decodable::decode(decoder))
364367
}
365368
}
366369

compiler/rustc_middle/src/ty/consts.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ pub type ConstKind<'tcx> = ir::ConstKind<TyCtxt<'tcx>>;
2020
pub type UnevaluatedConst<'tcx> = ir::UnevaluatedConst<TyCtxt<'tcx>>;
2121

2222
#[cfg(target_pointer_width = "64")]
23-
rustc_data_structures::static_assert_size!(ConstKind<'_>, 32);
23+
rustc_data_structures::static_assert_size!(ConstKind<'_>, 24);
2424

2525
#[derive(Copy, Clone, PartialEq, Eq, Hash, HashStable)]
2626
#[rustc_pass_by_value]
@@ -190,15 +190,15 @@ impl<'tcx> Const<'tcx> {
190190
.size;
191191
ty::Const::new_value(
192192
tcx,
193-
ty::ValTree::from_scalar_int(ScalarInt::try_from_uint(bits, size).unwrap()),
193+
ty::ValTree::from_scalar_int(tcx, ScalarInt::try_from_uint(bits, size).unwrap()),
194194
ty,
195195
)
196196
}
197197

198198
#[inline]
199199
/// Creates an interned zst constant.
200200
pub fn zero_sized(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Self {
201-
ty::Const::new_value(tcx, ty::ValTree::zst(), ty)
201+
ty::Const::new_value(tcx, ty::ValTree::zst(tcx), ty)
202202
}
203203

204204
#[inline]

compiler/rustc_middle/src/ty/consts/valtree.rs

+71-27
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
use rustc_macros::{HashStable, TyDecodable, TyEncodable, TypeFoldable, TypeVisitable};
1+
use std::fmt;
2+
use std::ops::Deref;
3+
4+
use rustc_data_structures::intern::Interned;
5+
use rustc_macros::{HashStable, Lift, TyDecodable, TyEncodable, TypeFoldable, TypeVisitable};
26

37
use super::ScalarInt;
48
use crate::mir::interpret::Scalar;
@@ -16,9 +20,9 @@ use crate::ty::{self, Ty, TyCtxt};
1620
///
1721
/// `ValTree` does not have this problem with representation, as it only contains integers or
1822
/// lists of (nested) `ValTree`.
19-
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
23+
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
2024
#[derive(HashStable, TyEncodable, TyDecodable)]
21-
pub enum ValTree<'tcx> {
25+
pub enum ValTreeKind<'tcx> {
2226
/// integers, `bool`, `char` are represented as scalars.
2327
/// See the `ScalarInt` documentation for how `ScalarInt` guarantees that equal values
2428
/// of these types have the same representation.
@@ -33,58 +37,98 @@ pub enum ValTree<'tcx> {
3337
/// the fields of the variant.
3438
///
3539
/// ZST types are represented as an empty slice.
36-
Branch(&'tcx [ValTree<'tcx>]),
40+
Branch(Box<[ValTree<'tcx>]>),
3741
}
3842

39-
impl<'tcx> ValTree<'tcx> {
40-
pub fn zst() -> Self {
41-
Self::Branch(&[])
42-
}
43-
43+
impl<'tcx> ValTreeKind<'tcx> {
4444
#[inline]
45-
pub fn unwrap_leaf(self) -> ScalarInt {
45+
pub fn unwrap_leaf(&self) -> ScalarInt {
4646
match self {
47-
Self::Leaf(s) => s,
47+
Self::Leaf(s) => *s,
4848
_ => bug!("expected leaf, got {:?}", self),
4949
}
5050
}
5151

5252
#[inline]
53-
pub fn unwrap_branch(self) -> &'tcx [Self] {
53+
pub fn unwrap_branch(&self) -> &[ValTree<'tcx>] {
5454
match self {
55-
Self::Branch(branch) => branch,
55+
Self::Branch(branch) => &**branch,
5656
_ => bug!("expected branch, got {:?}", self),
5757
}
5858
}
5959

60-
pub fn from_raw_bytes<'a>(tcx: TyCtxt<'tcx>, bytes: &'a [u8]) -> Self {
61-
let branches = bytes.iter().map(|b| Self::Leaf(ScalarInt::from(*b)));
62-
let interned = tcx.arena.alloc_from_iter(branches);
60+
pub fn try_to_scalar(&self) -> Option<Scalar> {
61+
self.try_to_scalar_int().map(Scalar::Int)
62+
}
6363

64-
Self::Branch(interned)
64+
pub fn try_to_scalar_int(&self) -> Option<ScalarInt> {
65+
match self {
66+
Self::Leaf(s) => Some(*s),
67+
Self::Branch(_) => None,
68+
}
6569
}
6670

67-
pub fn from_scalar_int(i: ScalarInt) -> Self {
68-
Self::Leaf(i)
71+
pub fn try_to_branch(&self) -> Option<&[ValTree<'tcx>]> {
72+
match self {
73+
Self::Branch(branch) => Some(&**branch),
74+
Self::Leaf(_) => None,
75+
}
6976
}
77+
}
7078

71-
pub fn try_to_scalar(self) -> Option<Scalar> {
72-
self.try_to_scalar_int().map(Scalar::Int)
79+
/// An interned valtree. Use this rather than `ValTreeKind`, whenever possible.
80+
///
81+
/// See the docs of [`ValTreeKind`] or the [dev guide] for an explanation of this type.
82+
///
83+
/// [dev guide]: https://rustc-dev-guide.rust-lang.org/mir/index.html#valtrees
84+
#[derive(Copy, Clone, Hash, Eq, PartialEq)]
85+
#[derive(HashStable)]
86+
pub struct ValTree<'tcx>(pub(crate) Interned<'tcx, ValTreeKind<'tcx>>);
87+
88+
impl<'tcx> ValTree<'tcx> {
89+
/// Returns the zero-sized valtree: `Branch([])`.
90+
pub fn zst(tcx: TyCtxt<'tcx>) -> Self {
91+
tcx.consts.valtree_zst
7392
}
7493

75-
pub fn try_to_scalar_int(self) -> Option<ScalarInt> {
76-
match self {
77-
Self::Leaf(s) => Some(s),
78-
Self::Branch(_) => None,
79-
}
94+
pub fn is_zst(self) -> bool {
95+
matches!(*self, ValTreeKind::Branch(box []))
96+
}
97+
98+
pub fn from_raw_bytes(tcx: TyCtxt<'tcx>, bytes: &[u8]) -> Self {
99+
let branches = bytes.iter().map(|&b| Self::from_scalar_int(tcx, b.into()));
100+
Self::from_branches(tcx, branches)
101+
}
102+
103+
pub fn from_branches(tcx: TyCtxt<'tcx>, branches: impl IntoIterator<Item = Self>) -> Self {
104+
tcx.intern_valtree(ValTreeKind::Branch(branches.into_iter().collect()))
105+
}
106+
107+
pub fn from_scalar_int(tcx: TyCtxt<'tcx>, i: ScalarInt) -> Self {
108+
tcx.intern_valtree(ValTreeKind::Leaf(i))
109+
}
110+
}
111+
112+
impl<'tcx> Deref for ValTree<'tcx> {
113+
type Target = &'tcx ValTreeKind<'tcx>;
114+
115+
#[inline]
116+
fn deref(&self) -> &&'tcx ValTreeKind<'tcx> {
117+
&self.0.0
118+
}
119+
}
120+
121+
impl fmt::Debug for ValTree<'_> {
122+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123+
(**self).fmt(f)
80124
}
81125
}
82126

83127
/// A type-level constant value.
84128
///
85129
/// Represents a typed, fully evaluated constant.
86130
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
87-
#[derive(HashStable, TyEncodable, TyDecodable, TypeFoldable, TypeVisitable)]
131+
#[derive(HashStable, TyEncodable, TyDecodable, TypeFoldable, TypeVisitable, Lift)]
88132
pub struct Value<'tcx> {
89133
pub ty: Ty<'tcx>,
90134
pub valtree: ValTree<'tcx>,

0 commit comments

Comments
 (0)