Skip to content

Commit f8b8ae3

Browse files
committed
Auto merge of #56810 - sinkuu:build_match, r=<try>
Improve MIR match generation for ranges Improves MIR match generation to rule out ranges/values distinct from the range that has been tested. e.g., for this code: ```rust match x { 0..=5 if b => 0, 6..=10 => 1, _ => 2, } ``` MIR (before): ```rust bb0: { ...; _4 = Le(const 0i32, _1); switchInt(move _4) -> [false: bb6, otherwise: bb5]; } bb1: { _3 = const 0i32; goto -> bb8; } bb2: { _6 = _2; switchInt(move _6) -> [false: bb6, otherwise: bb1]; } // If `!b`, jumps to test if `6 <= x <= 10`. bb3: { _3 = const 1i32; goto -> bb8; } bb4: { _3 = const 2i32; goto -> bb8; } bb5: { _5 = Le(_1, const 5i32); switchInt(move _5) -> [false: bb6, otherwise: bb2]; } bb6: { _7 = Le(const 6i32, _1); switchInt(move _7) -> [false: bb4, otherwise: bb7]; } bb7: { _8 = Le(_1, const 10i32); switchInt(move _8) -> [false: bb4, otherwise: bb3]; } ``` MIR (after): ```rust bb0: { ...; _4 = Le(const 0i32, _1); switchInt(move _4) -> [false: bb5, otherwise: bb6]; } bb1: { _3 = const 0i32; goto -> bb8; } bb2: { _6 = _2; switchInt(move _6) -> [false: bb4, otherwise: bb1]; } // If `!b`, jumps to `_ =>` arm. bb3: { _3 = const 1i32; goto -> bb8; } bb4: { _3 = const 2i32; goto -> bb8; } bb5: { _7 = Le(const 6i32, _1); switchInt(move _7) -> [false: bb4, otherwise: bb7]; } bb6: { _5 = Le(_1, const 5i32); switchInt(move _5) -> [false: bb5, otherwise: bb2]; } bb7: { _8 = Le(_1, const 10i32); switchInt(move _8) -> [false: bb4, otherwise: bb3]; } ``` cc #29623
2 parents 63f8e6e + d66a55e commit f8b8ae3

File tree

8 files changed

+339
-69
lines changed

8 files changed

+339
-69
lines changed

src/librustc_mir/build/matches/mod.rs

+6-12
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ use build::{BlockAnd, BlockAndExtension, Builder};
1919
use build::{GuardFrame, GuardFrameLocal, LocalsForNode};
2020
use hair::*;
2121
use hair::pattern::PatternTypeProjections;
22-
use rustc::hir;
2322
use rustc::mir::*;
2423
use rustc::ty::{self, Ty};
2524
use rustc::ty::layout::VariantIdx;
@@ -100,7 +99,7 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
10099
.collect();
101100

102101
// create binding start block for link them by false edges
103-
let candidate_count = arms.iter().fold(0, |ac, c| ac + c.patterns.len());
102+
let candidate_count = arms.iter().map(|c| c.patterns.len()).sum::<usize>();
104103
let pre_binding_blocks: Vec<_> = (0..=candidate_count)
105104
.map(|_| self.cfg.start_new_block())
106105
.collect();
@@ -337,7 +336,7 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
337336

338337
pub fn place_into_pattern(
339338
&mut self,
340-
mut block: BasicBlock,
339+
block: BasicBlock,
341340
irrefutable_pat: Pattern<'tcx>,
342341
initializer: &Place<'tcx>,
343342
set_match_place: bool,
@@ -359,7 +358,7 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
359358

360359
// Simplify the candidate. Since the pattern is irrefutable, this should
361360
// always convert all match-pairs into bindings.
362-
unpack!(block = self.simplify_candidate(block, &mut candidate));
361+
self.simplify_candidate(&mut candidate);
363362

364363
if !candidate.match_pairs.is_empty() {
365364
span_bug!(
@@ -681,12 +680,7 @@ enum TestKind<'tcx> {
681680
},
682681

683682
// test whether the value falls within an inclusive or exclusive range
684-
Range {
685-
lo: &'tcx ty::Const<'tcx>,
686-
hi: &'tcx ty::Const<'tcx>,
687-
ty: Ty<'tcx>,
688-
end: hir::RangeEnd,
689-
},
683+
Range(PatternRange<'tcx>),
690684

691685
// test length of the slice is equal to len
692686
Len {
@@ -745,7 +739,7 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
745739
// complete, all the match pairs which remain require some
746740
// form of test, whether it be a switch or pattern comparison.
747741
for candidate in &mut candidates {
748-
unpack!(block = self.simplify_candidate(block, candidate));
742+
self.simplify_candidate(candidate);
749743
}
750744

751745
// The candidates are sorted by priority. Check to see
@@ -1035,7 +1029,7 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
10351029
test, match_pair
10361030
);
10371031
let target_blocks = self.perform_test(block, &match_pair.place, &test);
1038-
let mut target_candidates: Vec<_> = (0..target_blocks.len()).map(|_| vec![]).collect();
1032+
let mut target_candidates = vec![vec![]; target_blocks.len()];
10391033

10401034
// Sort the candidates into the appropriate vector in
10411035
// `target_candidates`. Note that at some point we may

src/librustc_mir/build/matches/simplify.rs

+9-11
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@
2222
//! sort of test: for example, testing which variant an enum is, or
2323
//! testing a value against a constant.
2424
25-
use build::{BlockAnd, BlockAndExtension, Builder};
25+
use build::Builder;
2626
use build::matches::{Ascription, Binding, MatchPair, Candidate};
2727
use hair::*;
28-
use rustc::mir::*;
2928
use rustc::ty;
3029
use rustc::ty::layout::{Integer, IntegerExt, Size};
3130
use syntax::attr::{SignedInt, UnsignedInt};
@@ -35,24 +34,23 @@ use std::mem;
3534

3635
impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
3736
pub fn simplify_candidate<'pat>(&mut self,
38-
block: BasicBlock,
39-
candidate: &mut Candidate<'pat, 'tcx>)
40-
-> BlockAnd<()> {
37+
candidate: &mut Candidate<'pat, 'tcx>) {
4138
// repeatedly simplify match pairs until fixed point is reached
4239
loop {
4340
let match_pairs = mem::replace(&mut candidate.match_pairs, vec![]);
44-
let mut progress = match_pairs.len(); // count how many were simplified
41+
let mut changed = false;
4542
for match_pair in match_pairs {
4643
match self.simplify_match_pair(match_pair, candidate) {
47-
Ok(()) => {}
44+
Ok(()) => {
45+
changed = true;
46+
}
4847
Err(match_pair) => {
4948
candidate.match_pairs.push(match_pair);
50-
progress -= 1; // this one was not simplified
5149
}
5250
}
5351
}
54-
if progress == 0 {
55-
return block.unit(); // if we were not able to simplify any, done.
52+
if !changed {
53+
return; // if we were not able to simplify any, done.
5654
}
5755
}
5856
}
@@ -109,7 +107,7 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
109107
Err(match_pair)
110108
}
111109

112-
PatternKind::Range { lo, hi, ty, end } => {
110+
PatternKind::Range(PatternRange { lo, hi, ty, end }) => {
113111
let range = match ty.sty {
114112
ty::Char => {
115113
Some(('\u{0000}' as u128, '\u{10FFFF}' as u128, Size::from_bits(32)))

src/librustc_mir/build/matches/test.rs

+128-24
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
use build::Builder;
1919
use build::matches::{Candidate, MatchPair, Test, TestKind};
2020
use hair::*;
21+
use hair::pattern::compare_const_vals;
2122
use rustc_data_structures::bit_set::BitSet;
2223
use rustc_data_structures::fx::FxHashMap;
2324
use rustc::ty::{self, Ty};
@@ -71,16 +72,11 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
7172
}
7273
}
7374

74-
PatternKind::Range { lo, hi, ty, end } => {
75-
assert!(ty == match_pair.pattern.ty);
75+
PatternKind::Range(range) => {
76+
assert!(range.ty == match_pair.pattern.ty);
7677
Test {
7778
span: match_pair.pattern.span,
78-
kind: TestKind::Range {
79-
lo,
80-
hi,
81-
ty,
82-
end,
83-
},
79+
kind: TestKind::Range(range),
8480
}
8581
}
8682

@@ -136,7 +132,11 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
136132
PatternKind::Variant { .. } => {
137133
panic!("you should have called add_variants_to_switch instead!");
138134
}
139-
PatternKind::Range { .. } |
135+
PatternKind::Range(range) => {
136+
// Check that none of the switch values are in the range.
137+
self.values_not_contained_in_range(range, indices)
138+
.unwrap_or(false)
139+
}
140140
PatternKind::Slice { .. } |
141141
PatternKind::Array { .. } |
142142
PatternKind::Wild |
@@ -200,20 +200,18 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
200200
for (idx, discr) in adt_def.discriminants(tcx) {
201201
target_blocks.push(if variants.contains(idx) {
202202
values.push(discr.val);
203-
targets.push(self.cfg.start_new_block());
204-
*targets.last().unwrap()
203+
let block = self.cfg.start_new_block();
204+
targets.push(block);
205+
block
205206
} else {
206-
if otherwise_block.is_none() {
207-
otherwise_block = Some(self.cfg.start_new_block());
208-
}
209-
otherwise_block.unwrap()
207+
*otherwise_block
208+
.get_or_insert_with(|| self.cfg.start_new_block())
210209
});
211210
}
212-
if let Some(otherwise_block) = otherwise_block {
213-
targets.push(otherwise_block);
214-
} else {
215-
targets.push(self.unreachable_block());
216-
}
211+
targets.push(
212+
otherwise_block
213+
.unwrap_or_else(|| self.unreachable_block()),
214+
);
217215
debug!("num_enum_variants: {}, tested variants: {:?}, variants: {:?}",
218216
num_enum_variants, values, variants);
219217
let discr_ty = adt_def.repr.discr_type().to_ty(tcx);
@@ -378,7 +376,7 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
378376
}
379377
}
380378

381-
TestKind::Range { ref lo, ref hi, ty, ref end } => {
379+
TestKind::Range(PatternRange { ref lo, ref hi, ty, ref end }) => {
382380
// Test `val` by computing `lo <= val && val <= hi`, using primitive comparisons.
383381
let lo = self.literal_operand(test.span, ty.clone(), lo.clone());
384382
let hi = self.literal_operand(test.span, ty.clone(), hi.clone());
@@ -490,8 +488,7 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
490488
// away.)
491489
let tested_match_pair = candidate.match_pairs.iter()
492490
.enumerate()
493-
.filter(|&(_, mp)| mp.place == *test_place)
494-
.next();
491+
.find(|&(_, mp)| mp.place == *test_place);
495492
let (match_pair_index, match_pair) = match tested_match_pair {
496493
Some(pair) => pair,
497494
None => {
@@ -532,6 +529,24 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
532529
resulting_candidates[index].push(new_candidate);
533530
true
534531
}
532+
533+
(&TestKind::SwitchInt { switch_ty: _, ref options, ref indices },
534+
&PatternKind::Range(range)) => {
535+
let not_contained = self
536+
.values_not_contained_in_range(range, indices)
537+
.unwrap_or(false);
538+
539+
if not_contained {
540+
// No switch values are contained in the pattern range,
541+
// so the pattern can be matched only if this test fails.
542+
let otherwise = options.len();
543+
resulting_candidates[otherwise].push(candidate.clone());
544+
true
545+
} else {
546+
false
547+
}
548+
}
549+
535550
(&TestKind::SwitchInt { .. }, _) => false,
536551

537552

@@ -610,8 +625,63 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
610625
}
611626
}
612627

628+
(&TestKind::Range(test),
629+
&PatternKind::Range(pat)) => {
630+
if test == pat {
631+
resulting_candidates[0]
632+
.push(self.candidate_without_match_pair(
633+
match_pair_index,
634+
candidate,
635+
));
636+
return true;
637+
}
638+
639+
let no_overlap = (|| {
640+
use std::cmp::Ordering::*;
641+
use rustc::hir::RangeEnd::*;
642+
643+
let param_env = ty::ParamEnv::empty().and(test.ty);
644+
let tcx = self.hir.tcx();
645+
646+
let lo = compare_const_vals(tcx, test.lo, pat.hi, param_env)?;
647+
let hi = compare_const_vals(tcx, test.hi, pat.lo, param_env)?;
648+
649+
match (test.end, pat.end, lo, hi) {
650+
// pat < test
651+
(_, _, Greater, _) |
652+
(_, Excluded, Equal, _) |
653+
// pat > test
654+
(_, _, _, Less) |
655+
(Excluded, _, _, Equal) => Some(true),
656+
_ => Some(false),
657+
}
658+
})();
659+
660+
if no_overlap == Some(true) {
661+
// Testing range does not overlap with pattern range,
662+
// so the pattern can be matched only if this test fails.
663+
resulting_candidates[1].push(candidate.clone());
664+
true
665+
} else {
666+
false
667+
}
668+
}
669+
670+
(&TestKind::Range(range), &PatternKind::Constant { ref value }) => {
671+
if self.const_range_contains(range, value) == Some(false) {
672+
// `value` is not contained in the testing range,
673+
// so `value` can be matched only if this test fails.
674+
resulting_candidates[1].push(candidate.clone());
675+
true
676+
} else {
677+
false
678+
}
679+
}
680+
681+
(&TestKind::Range { .. }, _) => false,
682+
683+
613684
(&TestKind::Eq { .. }, _) |
614-
(&TestKind::Range { .. }, _) |
615685
(&TestKind::Len { .. }, _) => {
616686
// These are all binary tests.
617687
//
@@ -722,6 +792,40 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
722792
"simplifyable pattern found: {:?}",
723793
match_pair.pattern)
724794
}
795+
796+
fn const_range_contains(
797+
&self,
798+
range: PatternRange<'tcx>,
799+
value: &'tcx ty::Const<'tcx>,
800+
) -> Option<bool> {
801+
use std::cmp::Ordering::*;
802+
803+
let param_env = ty::ParamEnv::empty().and(range.ty);
804+
let tcx = self.hir.tcx();
805+
806+
let a = compare_const_vals(tcx, range.lo, value, param_env)?;
807+
let b = compare_const_vals(tcx, value, range.hi, param_env)?;
808+
809+
match (b, range.end) {
810+
(Less, _) |
811+
(Equal, RangeEnd::Included) if a != Greater => Some(true),
812+
_ => Some(false),
813+
}
814+
}
815+
816+
fn values_not_contained_in_range(
817+
&self,
818+
range: PatternRange<'tcx>,
819+
indices: &FxHashMap<&'tcx ty::Const<'tcx>, usize>,
820+
) -> Option<bool> {
821+
for val in indices.keys() {
822+
if self.const_range_contains(range, val)? {
823+
return Some(false);
824+
}
825+
}
826+
827+
Some(true)
828+
}
725829
}
726830

727831
fn is_switch_ty<'tcx>(ty: Ty<'tcx>) -> bool {

src/librustc_mir/hair/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ pub mod cx;
2929
mod constant;
3030

3131
pub mod pattern;
32-
pub use self::pattern::{BindingMode, Pattern, PatternKind, FieldPattern};
32+
pub use self::pattern::{BindingMode, Pattern, PatternKind, PatternRange, FieldPattern};
3333
pub(crate) use self::pattern::{PatternTypeProjection, PatternTypeProjections};
3434

3535
mod util;

0 commit comments

Comments
 (0)