Skip to content

Commit c78e3c7

Browse files
authored
Rollup merge of #107411 - cjgillot:dataflow-discriminant, r=oli-obk
Handle discriminant in DataflowConstProp cc ``@jachris`` r? ``@JakobDegen`` This PR attempts to extend the DataflowConstProp pass to handle propagation of discriminants. We handle this by adding 2 new variants to `TrackElem`: `TrackElem::Variant` for enum variants and `TrackElem::Discriminant` for the enum discriminant pseudo-place. The difficulty is that the enum discriminant and enum variants may alias each other. This is the issue exercised by the `Option<NonZeroUsize>` test, which is the equivalent of rust-lang/unsafe-code-guidelines#84 with a direct write. To handle that, we generalize the flood process to flood all the potentially aliasing places. In particular: - any write to `(PLACE as Variant)`, either direct or through a projection, floods `(PLACE as OtherVariant)` for all other variants and `discriminant(PLACE)`; - `SetDiscriminant(PLACE)` floods `(PLACE as Variant)` for each variant. This implies that flooding is not hierarchical any more, and that an assignment to a non-tracked place may need to flood a tracked place. This is handled by `for_each_aliasing_place`, which generalizes `preorder_invoke`. As we deaggregate enums by putting `SetDiscriminant` last, this allows us to propagate the value of the discriminant. This refactor will also allow #107009 to handle discriminants.
2 parents a110cf5 + 09797a4 commit c78e3c7

File tree

9 files changed

+408
-112
lines changed

9 files changed

+408
-112
lines changed

compiler/rustc_middle/src/mir/mod.rs

+8
Original file line numberDiff line numberDiff line change
@@ -1642,6 +1642,14 @@ impl<'tcx> PlaceRef<'tcx> {
16421642
}
16431643
}
16441644

1645+
/// Returns `true` if this `Place` contains a `Deref` projection.
1646+
///
1647+
/// If `Place::is_indirect` returns false, the caller knows that the `Place` refers to the
1648+
/// same region of memory as its base.
1649+
pub fn is_indirect(&self) -> bool {
1650+
self.projection.iter().any(|elem| elem.is_indirect())
1651+
}
1652+
16451653
/// If MirPhase >= Derefered and if projection contains Deref,
16461654
/// It's guaranteed to be in the first place
16471655
pub fn has_deref(&self) -> bool {

compiler/rustc_mir_dataflow/src/impls/borrowed_locals.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ where
121121
// for now. See discussion on [#61069].
122122
//
123123
// [#61069]: https://github.com/rust-lang/rust/pull/61069
124-
self.trans.gen(dropped_place.local);
124+
if !dropped_place.is_indirect() {
125+
self.trans.gen(dropped_place.local);
126+
}
125127
}
126128

127129
TerminatorKind::Abort

compiler/rustc_mir_dataflow/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#![feature(associated_type_defaults)]
22
#![feature(box_patterns)]
33
#![feature(exact_size_is_empty)]
4+
#![feature(let_chains)]
45
#![feature(min_specialization)]
56
#![feature(once_cell)]
67
#![feature(stmt_expr_attributes)]

compiler/rustc_mir_dataflow/src/value_analysis.rs

+214-61
Large diffs are not rendered by default.

compiler/rustc_mir_transform/src/dataflow_const_prop.rs

+95-32
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use rustc_mir_dataflow::value_analysis::{Map, State, TrackElem, ValueAnalysis, V
1313
use rustc_mir_dataflow::{lattice::FlatSet, Analysis, ResultsVisitor, SwitchIntEdgeEffects};
1414
use rustc_span::DUMMY_SP;
1515
use rustc_target::abi::Align;
16+
use rustc_target::abi::VariantIdx;
1617

1718
use crate::MirPass;
1819

@@ -30,14 +31,12 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
3031

3132
#[instrument(skip_all level = "debug")]
3233
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
34+
debug!(def_id = ?body.source.def_id());
3335
if tcx.sess.mir_opt_level() < 4 && body.basic_blocks.len() > BLOCK_LIMIT {
3436
debug!("aborted dataflow const prop due too many basic blocks");
3537
return;
3638
}
3739

38-
// Decide which places to track during the analysis.
39-
let map = Map::from_filter(tcx, body, Ty::is_scalar);
40-
4140
// We want to have a somewhat linear runtime w.r.t. the number of statements/terminators.
4241
// Let's call this number `n`. Dataflow analysis has `O(h*n)` transfer function
4342
// applications, where `h` is the height of the lattice. Because the height of our lattice
@@ -46,10 +45,10 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
4645
// `O(num_nodes * tracked_places * n)` in terms of time complexity. Since the number of
4746
// map nodes is strongly correlated to the number of tracked places, this becomes more or
4847
// less `O(n)` if we place a constant limit on the number of tracked places.
49-
if tcx.sess.mir_opt_level() < 4 && map.tracked_places() > PLACE_LIMIT {
50-
debug!("aborted dataflow const prop due to too many tracked places");
51-
return;
52-
}
48+
let place_limit = if tcx.sess.mir_opt_level() < 4 { Some(PLACE_LIMIT) } else { None };
49+
50+
// Decide which places to track during the analysis.
51+
let map = Map::from_filter(tcx, body, Ty::is_scalar, place_limit);
5352

5453
// Perform the actual dataflow analysis.
5554
let analysis = ConstAnalysis::new(tcx, body, map);
@@ -63,14 +62,31 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
6362
}
6463
}
6564

66-
struct ConstAnalysis<'tcx> {
65+
struct ConstAnalysis<'a, 'tcx> {
6766
map: Map,
6867
tcx: TyCtxt<'tcx>,
68+
local_decls: &'a LocalDecls<'tcx>,
6969
ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
7070
param_env: ty::ParamEnv<'tcx>,
7171
}
7272

73-
impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
73+
impl<'tcx> ConstAnalysis<'_, 'tcx> {
74+
fn eval_discriminant(
75+
&self,
76+
enum_ty: Ty<'tcx>,
77+
variant_index: VariantIdx,
78+
) -> Option<ScalarTy<'tcx>> {
79+
if !enum_ty.is_enum() {
80+
return None;
81+
}
82+
let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
83+
let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
84+
let discr_value = Scalar::try_from_uint(discr.val, discr_layout.size)?;
85+
Some(ScalarTy(discr_value, discr.ty))
86+
}
87+
}
88+
89+
impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> {
7490
type Value = FlatSet<ScalarTy<'tcx>>;
7591

7692
const NAME: &'static str = "ConstAnalysis";
@@ -79,6 +95,25 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
7995
&self.map
8096
}
8197

98+
fn handle_statement(&self, statement: &Statement<'tcx>, state: &mut State<Self::Value>) {
99+
match statement.kind {
100+
StatementKind::SetDiscriminant { box ref place, variant_index } => {
101+
state.flood_discr(place.as_ref(), &self.map);
102+
if self.map.find_discr(place.as_ref()).is_some() {
103+
let enum_ty = place.ty(self.local_decls, self.tcx).ty;
104+
if let Some(discr) = self.eval_discriminant(enum_ty, variant_index) {
105+
state.assign_discr(
106+
place.as_ref(),
107+
ValueOrPlace::Value(FlatSet::Elem(discr)),
108+
&self.map,
109+
);
110+
}
111+
}
112+
}
113+
_ => self.super_statement(statement, state),
114+
}
115+
}
116+
82117
fn handle_assign(
83118
&self,
84119
target: Place<'tcx>,
@@ -87,36 +122,47 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
87122
) {
88123
match rvalue {
89124
Rvalue::Aggregate(kind, operands) => {
90-
let target = self.map().find(target.as_ref());
91-
if let Some(target) = target {
92-
state.flood_idx_with(target, self.map(), FlatSet::Bottom);
93-
let field_based = match **kind {
94-
AggregateKind::Tuple | AggregateKind::Closure(..) => true,
95-
AggregateKind::Adt(def_id, ..) => {
96-
matches!(self.tcx.def_kind(def_id), DefKind::Struct)
125+
state.flood_with(target.as_ref(), self.map(), FlatSet::Bottom);
126+
if let Some(target_idx) = self.map().find(target.as_ref()) {
127+
let (variant_target, variant_index) = match **kind {
128+
AggregateKind::Tuple | AggregateKind::Closure(..) => {
129+
(Some(target_idx), None)
97130
}
98-
_ => false,
131+
AggregateKind::Adt(def_id, variant_index, ..) => {
132+
match self.tcx.def_kind(def_id) {
133+
DefKind::Struct => (Some(target_idx), None),
134+
DefKind::Enum => (Some(target_idx), Some(variant_index)),
135+
_ => (None, None),
136+
}
137+
}
138+
_ => (None, None),
99139
};
100-
if field_based {
140+
if let Some(target) = variant_target {
101141
for (field_index, operand) in operands.iter().enumerate() {
102142
if let Some(field) = self
103143
.map()
104144
.apply(target, TrackElem::Field(Field::from_usize(field_index)))
105145
{
106146
let result = self.handle_operand(operand, state);
107-
state.assign_idx(field, result, self.map());
147+
state.insert_idx(field, result, self.map());
108148
}
109149
}
110150
}
151+
if let Some(variant_index) = variant_index
152+
&& let Some(discr_idx) = self.map().apply(target_idx, TrackElem::Discriminant)
153+
{
154+
let enum_ty = target.ty(self.local_decls, self.tcx).ty;
155+
if let Some(discr_val) = self.eval_discriminant(enum_ty, variant_index) {
156+
state.insert_value_idx(discr_idx, FlatSet::Elem(discr_val), &self.map);
157+
}
158+
}
111159
}
112160
}
113161
Rvalue::CheckedBinaryOp(op, box (left, right)) => {
162+
// Flood everything now, so we can use `insert_value_idx` directly later.
163+
state.flood(target.as_ref(), self.map());
164+
114165
let target = self.map().find(target.as_ref());
115-
if let Some(target) = target {
116-
// We should not track any projections other than
117-
// what is overwritten below, but just in case...
118-
state.flood_idx(target, self.map());
119-
}
120166

121167
let value_target = target
122168
.and_then(|target| self.map().apply(target, TrackElem::Field(0_u32.into())));
@@ -127,7 +173,8 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
127173
let (val, overflow) = self.binary_op(state, *op, left, right);
128174

129175
if let Some(value_target) = value_target {
130-
state.assign_idx(value_target, ValueOrPlace::Value(val), self.map());
176+
// We have flooded `target` earlier.
177+
state.insert_value_idx(value_target, val, self.map());
131178
}
132179
if let Some(overflow_target) = overflow_target {
133180
let overflow = match overflow {
@@ -142,11 +189,8 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
142189
}
143190
FlatSet::Bottom => FlatSet::Bottom,
144191
};
145-
state.assign_idx(
146-
overflow_target,
147-
ValueOrPlace::Value(overflow),
148-
self.map(),
149-
);
192+
// We have flooded `target` earlier.
193+
state.insert_value_idx(overflow_target, overflow, self.map());
150194
}
151195
}
152196
}
@@ -195,6 +239,9 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
195239
FlatSet::Bottom => ValueOrPlace::Value(FlatSet::Bottom),
196240
FlatSet::Top => ValueOrPlace::Value(FlatSet::Top),
197241
},
242+
Rvalue::Discriminant(place) => {
243+
ValueOrPlace::Value(state.get_discr(place.as_ref(), self.map()))
244+
}
198245
_ => self.super_rvalue(rvalue, state),
199246
}
200247
}
@@ -268,12 +315,13 @@ impl<'tcx> std::fmt::Debug for ScalarTy<'tcx> {
268315
}
269316
}
270317

271-
impl<'tcx> ConstAnalysis<'tcx> {
272-
pub fn new(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, map: Map) -> Self {
318+
impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> {
319+
pub fn new(tcx: TyCtxt<'tcx>, body: &'a Body<'tcx>, map: Map) -> Self {
273320
let param_env = tcx.param_env(body.source.def_id());
274321
Self {
275322
map,
276323
tcx,
324+
local_decls: &body.local_decls,
277325
ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
278326
param_env: param_env,
279327
}
@@ -466,6 +514,21 @@ impl<'tcx, 'map, 'a> Visitor<'tcx> for OperandCollector<'tcx, 'map, 'a> {
466514
_ => (),
467515
}
468516
}
517+
518+
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
519+
match rvalue {
520+
Rvalue::Discriminant(place) => {
521+
match self.state.get_discr(place.as_ref(), self.visitor.map) {
522+
FlatSet::Top => (),
523+
FlatSet::Elem(value) => {
524+
self.visitor.before_effect.insert((location, *place), value);
525+
}
526+
FlatSet::Bottom => (),
527+
}
528+
}
529+
_ => self.super_rvalue(rvalue, location),
530+
}
531+
}
469532
}
470533

471534
struct DummyMachine;

compiler/rustc_mir_transform/src/sroa.rs

+10-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::MirPass;
2-
use rustc_index::bit_set::BitSet;
2+
use rustc_index::bit_set::{BitSet, GrowableBitSet};
33
use rustc_index::vec::IndexVec;
44
use rustc_middle::mir::patch::MirPatch;
55
use rustc_middle::mir::visit::*;
@@ -26,10 +26,12 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
2626
debug!(?replacements);
2727
let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
2828
if !all_dead_locals.is_empty() {
29-
for local in excluded.indices() {
30-
excluded[local] |= all_dead_locals.contains(local);
31-
}
32-
excluded.raw.resize(body.local_decls.len(), false);
29+
excluded.union(&all_dead_locals);
30+
excluded = {
31+
let mut growable = GrowableBitSet::from(excluded);
32+
growable.ensure(body.local_decls.len());
33+
growable.into()
34+
};
3335
} else {
3436
break;
3537
}
@@ -44,11 +46,11 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
4446
/// - the locals is a union or an enum;
4547
/// - the local's address is taken, and thus the relative addresses of the fields are observable to
4648
/// client code.
47-
fn escaping_locals(excluded: &IndexVec<Local, bool>, body: &Body<'_>) -> BitSet<Local> {
49+
fn escaping_locals(excluded: &BitSet<Local>, body: &Body<'_>) -> BitSet<Local> {
4850
let mut set = BitSet::new_empty(body.local_decls.len());
4951
set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
5052
for (local, decl) in body.local_decls().iter_enumerated() {
51-
if decl.ty.is_union() || decl.ty.is_enum() || excluded[local] {
53+
if decl.ty.is_union() || decl.ty.is_enum() || excluded.contains(local) {
5254
set.insert(local);
5355
}
5456
}
@@ -172,7 +174,7 @@ fn replace_flattened_locals<'tcx>(
172174
body: &mut Body<'tcx>,
173175
replacements: ReplacementMap<'tcx>,
174176
) -> BitSet<Local> {
175-
let mut all_dead_locals = BitSet::new_empty(body.local_decls.len());
177+
let mut all_dead_locals = BitSet::new_empty(replacements.fragments.len());
176178
for (local, replacements) in replacements.fragments.iter_enumerated() {
177179
if replacements.is_some() {
178180
all_dead_locals.insert(local);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
- // MIR for `mutate_discriminant` before DataflowConstProp
2+
+ // MIR for `mutate_discriminant` after DataflowConstProp
3+
4+
fn mutate_discriminant() -> u8 {
5+
let mut _0: u8; // return place in scope 0 at $DIR/enum.rs:+0:29: +0:31
6+
let mut _1: std::option::Option<NonZeroUsize>; // in scope 0 at $SRC_DIR/core/src/intrinsics/mir.rs:LL:COL
7+
let mut _2: isize; // in scope 0 at $SRC_DIR/core/src/intrinsics/mir.rs:LL:COL
8+
9+
bb0: {
10+
discriminant(_1) = 1; // scope 0 at $DIR/enum.rs:+4:13: +4:34
11+
(((_1 as variant#1).0: NonZeroUsize).0: usize) = const 0_usize; // scope 0 at $DIR/enum.rs:+6:13: +6:64
12+
_2 = discriminant(_1); // scope 0 at $SRC_DIR/core/src/intrinsics/mir.rs:LL:COL
13+
switchInt(_2) -> [0: bb1, otherwise: bb2]; // scope 0 at $DIR/enum.rs:+9:13: +12:14
14+
}
15+
16+
bb1: {
17+
_0 = const 1_u8; // scope 0 at $DIR/enum.rs:+15:13: +15:20
18+
return; // scope 0 at $DIR/enum.rs:+16:13: +16:21
19+
}
20+
21+
bb2: {
22+
_0 = const 2_u8; // scope 0 at $DIR/enum.rs:+19:13: +19:20
23+
unreachable; // scope 0 at $DIR/enum.rs:+20:13: +20:26
24+
}
25+
}
26+
+42-3
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,52 @@
11
// unit-test: DataflowConstProp
22

3-
// Not trackable, because variants could be aliased.
3+
#![feature(custom_mir, core_intrinsics, rustc_attrs)]
4+
5+
use std::intrinsics::mir::*;
6+
47
enum E {
58
V1(i32),
69
V2(i32)
710
}
811

9-
// EMIT_MIR enum.main.DataflowConstProp.diff
10-
fn main() {
12+
// EMIT_MIR enum.simple.DataflowConstProp.diff
13+
fn simple() {
1114
let e = E::V1(0);
1215
let x = match e { E::V1(x) => x, E::V2(x) => x };
1316
}
17+
18+
#[rustc_layout_scalar_valid_range_start(1)]
19+
#[rustc_nonnull_optimization_guaranteed]
20+
struct NonZeroUsize(usize);
21+
22+
// EMIT_MIR enum.mutate_discriminant.DataflowConstProp.diff
23+
#[custom_mir(dialect = "runtime", phase = "post-cleanup")]
24+
fn mutate_discriminant() -> u8 {
25+
mir!(
26+
let x: Option<NonZeroUsize>;
27+
{
28+
SetDiscriminant(x, 1);
29+
// This assignment overwrites the niche in which the discriminant is stored.
30+
place!(Field(Field(Variant(x, 1), 0), 0)) = 0_usize;
31+
// So we cannot know the value of this discriminant.
32+
let a = Discriminant(x);
33+
match a {
34+
0 => bb1,
35+
_ => bad,
36+
}
37+
}
38+
bb1 = {
39+
RET = 1;
40+
Return()
41+
}
42+
bad = {
43+
RET = 2;
44+
Unreachable()
45+
}
46+
)
47+
}
48+
49+
fn main() {
50+
simple();
51+
mutate_discriminant();
52+
}

0 commit comments

Comments
 (0)