@@ -13,6 +13,7 @@ use rustc_mir_dataflow::value_analysis::{Map, State, TrackElem, ValueAnalysis, V
13
13
use rustc_mir_dataflow:: { lattice:: FlatSet , Analysis , ResultsVisitor , SwitchIntEdgeEffects } ;
14
14
use rustc_span:: DUMMY_SP ;
15
15
use rustc_target:: abi:: Align ;
16
+ use rustc_target:: abi:: VariantIdx ;
16
17
17
18
use crate :: MirPass ;
18
19
@@ -30,14 +31,12 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
30
31
31
32
#[ instrument( skip_all level = "debug" ) ]
32
33
fn run_pass ( & self , tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) {
34
+ debug ! ( def_id = ?body. source. def_id( ) ) ;
33
35
if tcx. sess . mir_opt_level ( ) < 4 && body. basic_blocks . len ( ) > BLOCK_LIMIT {
34
36
debug ! ( "aborted dataflow const prop due too many basic blocks" ) ;
35
37
return ;
36
38
}
37
39
38
- // Decide which places to track during the analysis.
39
- let map = Map :: from_filter ( tcx, body, Ty :: is_scalar) ;
40
-
41
40
// We want to have a somewhat linear runtime w.r.t. the number of statements/terminators.
42
41
// Let's call this number `n`. Dataflow analysis has `O(h*n)` transfer function
43
42
// applications, where `h` is the height of the lattice. Because the height of our lattice
@@ -46,10 +45,10 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
46
45
// `O(num_nodes * tracked_places * n)` in terms of time complexity. Since the number of
47
46
// map nodes is strongly correlated to the number of tracked places, this becomes more or
48
47
// less `O(n)` if we place a constant limit on the number of tracked places.
49
- if tcx. sess . mir_opt_level ( ) < 4 && map . tracked_places ( ) > PLACE_LIMIT {
50
- debug ! ( "aborted dataflow const prop due to too many tracked places" ) ;
51
- return ;
52
- }
48
+ let place_limit = if tcx. sess . mir_opt_level ( ) < 4 { Some ( PLACE_LIMIT ) } else { None } ;
49
+
50
+ // Decide which places to track during the analysis.
51
+ let map = Map :: from_filter ( tcx , body , Ty :: is_scalar , place_limit ) ;
53
52
54
53
// Perform the actual dataflow analysis.
55
54
let analysis = ConstAnalysis :: new ( tcx, body, map) ;
@@ -63,14 +62,31 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
63
62
}
64
63
}
65
64
66
- struct ConstAnalysis < ' tcx > {
65
+ struct ConstAnalysis < ' a , ' tcx > {
67
66
map : Map ,
68
67
tcx : TyCtxt < ' tcx > ,
68
+ local_decls : & ' a LocalDecls < ' tcx > ,
69
69
ecx : InterpCx < ' tcx , ' tcx , DummyMachine > ,
70
70
param_env : ty:: ParamEnv < ' tcx > ,
71
71
}
72
72
73
- impl < ' tcx > ValueAnalysis < ' tcx > for ConstAnalysis < ' tcx > {
73
+ impl < ' tcx > ConstAnalysis < ' _ , ' tcx > {
74
+ fn eval_discriminant (
75
+ & self ,
76
+ enum_ty : Ty < ' tcx > ,
77
+ variant_index : VariantIdx ,
78
+ ) -> Option < ScalarTy < ' tcx > > {
79
+ if !enum_ty. is_enum ( ) {
80
+ return None ;
81
+ }
82
+ let discr = enum_ty. discriminant_for_variant ( self . tcx , variant_index) ?;
83
+ let discr_layout = self . tcx . layout_of ( self . param_env . and ( discr. ty ) ) . ok ( ) ?;
84
+ let discr_value = Scalar :: try_from_uint ( discr. val , discr_layout. size ) ?;
85
+ Some ( ScalarTy ( discr_value, discr. ty ) )
86
+ }
87
+ }
88
+
89
+ impl < ' tcx > ValueAnalysis < ' tcx > for ConstAnalysis < ' _ , ' tcx > {
74
90
type Value = FlatSet < ScalarTy < ' tcx > > ;
75
91
76
92
const NAME : & ' static str = "ConstAnalysis" ;
@@ -79,6 +95,25 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
79
95
& self . map
80
96
}
81
97
98
+ fn handle_statement ( & self , statement : & Statement < ' tcx > , state : & mut State < Self :: Value > ) {
99
+ match statement. kind {
100
+ StatementKind :: SetDiscriminant { box ref place, variant_index } => {
101
+ state. flood_discr ( place. as_ref ( ) , & self . map ) ;
102
+ if self . map . find_discr ( place. as_ref ( ) ) . is_some ( ) {
103
+ let enum_ty = place. ty ( self . local_decls , self . tcx ) . ty ;
104
+ if let Some ( discr) = self . eval_discriminant ( enum_ty, variant_index) {
105
+ state. assign_discr (
106
+ place. as_ref ( ) ,
107
+ ValueOrPlace :: Value ( FlatSet :: Elem ( discr) ) ,
108
+ & self . map ,
109
+ ) ;
110
+ }
111
+ }
112
+ }
113
+ _ => self . super_statement ( statement, state) ,
114
+ }
115
+ }
116
+
82
117
fn handle_assign (
83
118
& self ,
84
119
target : Place < ' tcx > ,
@@ -87,36 +122,47 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
87
122
) {
88
123
match rvalue {
89
124
Rvalue :: Aggregate ( kind, operands) => {
90
- let target = self . map ( ) . find ( target. as_ref ( ) ) ;
91
- if let Some ( target) = target {
92
- state. flood_idx_with ( target, self . map ( ) , FlatSet :: Bottom ) ;
93
- let field_based = match * * kind {
94
- AggregateKind :: Tuple | AggregateKind :: Closure ( ..) => true ,
95
- AggregateKind :: Adt ( def_id, ..) => {
96
- matches ! ( self . tcx. def_kind( def_id) , DefKind :: Struct )
125
+ state. flood_with ( target. as_ref ( ) , self . map ( ) , FlatSet :: Bottom ) ;
126
+ if let Some ( target_idx) = self . map ( ) . find ( target. as_ref ( ) ) {
127
+ let ( variant_target, variant_index) = match * * kind {
128
+ AggregateKind :: Tuple | AggregateKind :: Closure ( ..) => {
129
+ ( Some ( target_idx) , None )
97
130
}
98
- _ => false ,
131
+ AggregateKind :: Adt ( def_id, variant_index, ..) => {
132
+ match self . tcx . def_kind ( def_id) {
133
+ DefKind :: Struct => ( Some ( target_idx) , None ) ,
134
+ DefKind :: Enum => ( Some ( target_idx) , Some ( variant_index) ) ,
135
+ _ => ( None , None ) ,
136
+ }
137
+ }
138
+ _ => ( None , None ) ,
99
139
} ;
100
- if field_based {
140
+ if let Some ( target ) = variant_target {
101
141
for ( field_index, operand) in operands. iter ( ) . enumerate ( ) {
102
142
if let Some ( field) = self
103
143
. map ( )
104
144
. apply ( target, TrackElem :: Field ( Field :: from_usize ( field_index) ) )
105
145
{
106
146
let result = self . handle_operand ( operand, state) ;
107
- state. assign_idx ( field, result, self . map ( ) ) ;
147
+ state. insert_idx ( field, result, self . map ( ) ) ;
108
148
}
109
149
}
110
150
}
151
+ if let Some ( variant_index) = variant_index
152
+ && let Some ( discr_idx) = self . map ( ) . apply ( target_idx, TrackElem :: Discriminant )
153
+ {
154
+ let enum_ty = target. ty ( self . local_decls , self . tcx ) . ty ;
155
+ if let Some ( discr_val) = self . eval_discriminant ( enum_ty, variant_index) {
156
+ state. insert_value_idx ( discr_idx, FlatSet :: Elem ( discr_val) , & self . map ) ;
157
+ }
158
+ }
111
159
}
112
160
}
113
161
Rvalue :: CheckedBinaryOp ( op, box ( left, right) ) => {
162
+ // Flood everything now, so we can use `insert_value_idx` directly later.
163
+ state. flood ( target. as_ref ( ) , self . map ( ) ) ;
164
+
114
165
let target = self . map ( ) . find ( target. as_ref ( ) ) ;
115
- if let Some ( target) = target {
116
- // We should not track any projections other than
117
- // what is overwritten below, but just in case...
118
- state. flood_idx ( target, self . map ( ) ) ;
119
- }
120
166
121
167
let value_target = target
122
168
. and_then ( |target| self . map ( ) . apply ( target, TrackElem :: Field ( 0_u32 . into ( ) ) ) ) ;
@@ -127,7 +173,8 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
127
173
let ( val, overflow) = self . binary_op ( state, * op, left, right) ;
128
174
129
175
if let Some ( value_target) = value_target {
130
- state. assign_idx ( value_target, ValueOrPlace :: Value ( val) , self . map ( ) ) ;
176
+ // We have flooded `target` earlier.
177
+ state. insert_value_idx ( value_target, val, self . map ( ) ) ;
131
178
}
132
179
if let Some ( overflow_target) = overflow_target {
133
180
let overflow = match overflow {
@@ -142,11 +189,8 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
142
189
}
143
190
FlatSet :: Bottom => FlatSet :: Bottom ,
144
191
} ;
145
- state. assign_idx (
146
- overflow_target,
147
- ValueOrPlace :: Value ( overflow) ,
148
- self . map ( ) ,
149
- ) ;
192
+ // We have flooded `target` earlier.
193
+ state. insert_value_idx ( overflow_target, overflow, self . map ( ) ) ;
150
194
}
151
195
}
152
196
}
@@ -195,6 +239,9 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
195
239
FlatSet :: Bottom => ValueOrPlace :: Value ( FlatSet :: Bottom ) ,
196
240
FlatSet :: Top => ValueOrPlace :: Value ( FlatSet :: Top ) ,
197
241
} ,
242
+ Rvalue :: Discriminant ( place) => {
243
+ ValueOrPlace :: Value ( state. get_discr ( place. as_ref ( ) , self . map ( ) ) )
244
+ }
198
245
_ => self . super_rvalue ( rvalue, state) ,
199
246
}
200
247
}
@@ -268,12 +315,13 @@ impl<'tcx> std::fmt::Debug for ScalarTy<'tcx> {
268
315
}
269
316
}
270
317
271
- impl < ' tcx > ConstAnalysis < ' tcx > {
272
- pub fn new ( tcx : TyCtxt < ' tcx > , body : & Body < ' tcx > , map : Map ) -> Self {
318
+ impl < ' a , ' tcx > ConstAnalysis < ' a , ' tcx > {
319
+ pub fn new ( tcx : TyCtxt < ' tcx > , body : & ' a Body < ' tcx > , map : Map ) -> Self {
273
320
let param_env = tcx. param_env ( body. source . def_id ( ) ) ;
274
321
Self {
275
322
map,
276
323
tcx,
324
+ local_decls : & body. local_decls ,
277
325
ecx : InterpCx :: new ( tcx, DUMMY_SP , param_env, DummyMachine ) ,
278
326
param_env : param_env,
279
327
}
@@ -466,6 +514,21 @@ impl<'tcx, 'map, 'a> Visitor<'tcx> for OperandCollector<'tcx, 'map, 'a> {
466
514
_ => ( ) ,
467
515
}
468
516
}
517
+
518
+ fn visit_rvalue ( & mut self , rvalue : & Rvalue < ' tcx > , location : Location ) {
519
+ match rvalue {
520
+ Rvalue :: Discriminant ( place) => {
521
+ match self . state . get_discr ( place. as_ref ( ) , self . visitor . map ) {
522
+ FlatSet :: Top => ( ) ,
523
+ FlatSet :: Elem ( value) => {
524
+ self . visitor . before_effect . insert ( ( location, * place) , value) ;
525
+ }
526
+ FlatSet :: Bottom => ( ) ,
527
+ }
528
+ }
529
+ _ => self . super_rvalue ( rvalue, location) ,
530
+ }
531
+ }
469
532
}
470
533
471
534
struct DummyMachine ;
0 commit comments