Skip to content

Commit ed7a206

Browse files
committed
Auto merge of #91844 - nnethercote:rm-ObligationCauseData-2, r=Mark-Simulacrum
Eliminate `ObligationCauseData` This makes `Obligation` two words bigger, but avoids allocating a lot of the time. I previously tried this in #73983 and it didn't help much, but local timings look more promising now. r? `@ghost`
2 parents e95e084 + f09b1fa commit ed7a206

File tree

22 files changed

+135
-139
lines changed

22 files changed

+135
-139
lines changed

compiler/rustc_infer/src/infer/error_reporting/mod.rs

+11-11
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
604604
exp_found: Option<ty::error::ExpectedFound<Ty<'tcx>>>,
605605
terr: &TypeError<'tcx>,
606606
) {
607-
match cause.code {
607+
match *cause.code() {
608608
ObligationCauseCode::Pattern { origin_expr: true, span: Some(span), root_ty } => {
609609
let ty = self.resolve_vars_if_possible(root_ty);
610610
if ty.is_suggestable() {
@@ -781,7 +781,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
781781
}
782782
_ => {
783783
if let ObligationCauseCode::BindingObligation(_, binding_span) =
784-
cause.code.peel_derives()
784+
cause.code().peel_derives()
785785
{
786786
if matches!(terr, TypeError::RegionsPlaceholderMismatch) {
787787
err.span_note(*binding_span, "the lifetime requirement is introduced here");
@@ -1729,10 +1729,10 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
17291729
}
17301730
_ => exp_found,
17311731
};
1732-
debug!("exp_found {:?} terr {:?} cause.code {:?}", exp_found, terr, cause.code);
1732+
debug!("exp_found {:?} terr {:?} cause.code {:?}", exp_found, terr, cause.code());
17331733
if let Some(exp_found) = exp_found {
17341734
let should_suggest_fixes = if let ObligationCauseCode::Pattern { root_ty, .. } =
1735-
&cause.code
1735+
cause.code()
17361736
{
17371737
// Skip if the root_ty of the pattern is not the same as the expected_ty.
17381738
// If these types aren't equal then we've probably peeled off a layer of arrays.
@@ -1827,15 +1827,15 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
18271827
exp_span, exp_found.expected, exp_found.found,
18281828
);
18291829

1830-
if let ObligationCauseCode::CompareImplMethodObligation { .. } = &cause.code {
1830+
if let ObligationCauseCode::CompareImplMethodObligation { .. } = cause.code() {
18311831
return;
18321832
}
18331833

18341834
match (
18351835
self.get_impl_future_output_ty(exp_found.expected),
18361836
self.get_impl_future_output_ty(exp_found.found),
18371837
) {
1838-
(Some(exp), Some(found)) if same_type_modulo_infer(exp, found) => match &cause.code {
1838+
(Some(exp), Some(found)) if same_type_modulo_infer(exp, found) => match cause.code() {
18391839
ObligationCauseCode::IfExpression(box IfExpressionCause { then, .. }) => {
18401840
diag.multipart_suggestion(
18411841
"consider `await`ing on both `Future`s",
@@ -1875,7 +1875,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
18751875
Applicability::MaybeIncorrect,
18761876
);
18771877
}
1878-
(Some(ty), _) if same_type_modulo_infer(ty, exp_found.found) => match cause.code {
1878+
(Some(ty), _) if same_type_modulo_infer(ty, exp_found.found) => match cause.code() {
18791879
ObligationCauseCode::Pattern { span: Some(span), .. }
18801880
| ObligationCauseCode::IfExpression(box IfExpressionCause { then: span, .. }) => {
18811881
diag.span_suggestion_verbose(
@@ -1927,7 +1927,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
19271927
.map(|field| (field.ident.name, field.ty(self.tcx, expected_substs)))
19281928
.find(|(_, ty)| same_type_modulo_infer(ty, exp_found.found))
19291929
{
1930-
if let ObligationCauseCode::Pattern { span: Some(span), .. } = cause.code {
1930+
if let ObligationCauseCode::Pattern { span: Some(span), .. } = *cause.code() {
19311931
if let Ok(snippet) = self.tcx.sess.source_map().span_to_snippet(span) {
19321932
let suggestion = if expected_def.is_struct() {
19331933
format!("{}.{}", snippet, name)
@@ -2064,7 +2064,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
20642064
}
20652065
}
20662066
if let MatchExpressionArm(box MatchExpressionArmCause { source, .. }) =
2067-
trace.cause.code
2067+
*trace.cause.code()
20682068
{
20692069
if let hir::MatchSource::TryDesugar = source {
20702070
if let Some((expected_ty, found_ty)) = self.values_str(trace.values) {
@@ -2659,7 +2659,7 @@ impl<'tcx> ObligationCauseExt<'tcx> for ObligationCause<'tcx> {
26592659
fn as_failure_code(&self, terr: &TypeError<'tcx>) -> FailureCode {
26602660
use self::FailureCode::*;
26612661
use crate::traits::ObligationCauseCode::*;
2662-
match self.code {
2662+
match self.code() {
26632663
CompareImplMethodObligation { .. } => Error0308("method not compatible with trait"),
26642664
CompareImplTypeObligation { .. } => Error0308("type not compatible with trait"),
26652665
MatchExpressionArm(box MatchExpressionArmCause { source, .. }) => {
@@ -2694,7 +2694,7 @@ impl<'tcx> ObligationCauseExt<'tcx> for ObligationCause<'tcx> {
26942694

26952695
fn as_requirement_str(&self) -> &'static str {
26962696
use crate::traits::ObligationCauseCode::*;
2697-
match self.code {
2697+
match self.code() {
26982698
CompareImplMethodObligation { .. } => "method type is compatible with trait",
26992699
CompareImplTypeObligation { .. } => "associated type is compatible with trait",
27002700
ExprAssignable => "expression is assignable",

compiler/rustc_infer/src/infer/error_reporting/nice_region_error/mismatched_static_lifetime.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> {
3131
};
3232
// If we added a "points at argument expression" obligation, we remove it here, we care
3333
// about the original obligation only.
34-
let code = match &cause.code {
34+
let code = match cause.code() {
3535
ObligationCauseCode::FunctionArgumentObligation { parent_code, .. } => &*parent_code,
36-
_ => &cause.code,
36+
_ => cause.code(),
3737
};
3838
let (parent, impl_def_id) = match code {
3939
ObligationCauseCode::MatchImpl(parent, impl_def_id) => (parent, impl_def_id),
4040
_ => return None,
4141
};
42-
let binding_span = match parent.code {
42+
let binding_span = match *parent.code() {
4343
ObligationCauseCode::BindingObligation(_def_id, binding_span) => binding_span,
4444
_ => return None,
4545
};

compiler/rustc_infer/src/infer/error_reporting/nice_region_error/placeholder_error.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ impl<'tcx> NiceRegionError<'_, 'tcx> {
208208
);
209209
let mut err = self.tcx().sess.struct_span_err(span, &msg);
210210

211-
let leading_ellipsis = if let ObligationCauseCode::ItemObligation(def_id) = cause.code {
211+
let leading_ellipsis = if let ObligationCauseCode::ItemObligation(def_id) = *cause.code() {
212212
err.span_label(span, "doesn't satisfy where-clause");
213213
err.span_label(
214214
self.tcx().def_span(def_id),

compiler/rustc_infer/src/infer/error_reporting/nice_region_error/static_impl_trait.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> {
4242
sup_r,
4343
) if **sub_r == RegionKind::ReStatic => {
4444
// This is for an implicit `'static` requirement coming from `impl dyn Trait {}`.
45-
if let ObligationCauseCode::UnifyReceiver(ctxt) = &cause.code {
45+
if let ObligationCauseCode::UnifyReceiver(ctxt) = cause.code() {
4646
// This may have a closure and it would cause ICE
4747
// through `find_param_with_region` (#78262).
4848
let anon_reg_sup = tcx.is_suitable_region(sup_r)?;
@@ -184,7 +184,7 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> {
184184
}
185185
if let SubregionOrigin::Subtype(box TypeTrace { cause, .. }) = sub_origin {
186186
if let ObligationCauseCode::ReturnValue(hir_id)
187-
| ObligationCauseCode::BlockTailExpression(hir_id) = &cause.code
187+
| ObligationCauseCode::BlockTailExpression(hir_id) = cause.code()
188188
{
189189
let parent_id = tcx.hir().get_parent_item(*hir_id);
190190
if let Some(fn_decl) = tcx.hir().fn_decl_by_hir_id(parent_id) {
@@ -226,7 +226,7 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> {
226226

227227
let mut override_error_code = None;
228228
if let SubregionOrigin::Subtype(box TypeTrace { cause, .. }) = &sup_origin {
229-
if let ObligationCauseCode::UnifyReceiver(ctxt) = &cause.code {
229+
if let ObligationCauseCode::UnifyReceiver(ctxt) = cause.code() {
230230
// Handle case of `impl Foo for dyn Bar { fn qux(&self) {} }` introducing a
231231
// `'static` lifetime when called as a method on a binding: `bar.qux()`.
232232
if self.find_impl_on_dyn_trait(&mut err, param.param_ty, &ctxt) {
@@ -235,9 +235,9 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> {
235235
}
236236
}
237237
if let SubregionOrigin::Subtype(box TypeTrace { cause, .. }) = &sub_origin {
238-
let code = match &cause.code {
239-
ObligationCauseCode::MatchImpl(parent, ..) => &parent.code,
240-
_ => &cause.code,
238+
let code = match cause.code() {
239+
ObligationCauseCode::MatchImpl(parent, ..) => parent.code(),
240+
_ => cause.code(),
241241
};
242242
if let (ObligationCauseCode::ItemObligation(item_def_id), None) =
243243
(code, override_error_code)

compiler/rustc_infer/src/infer/error_reporting/nice_region_error/trait_impl_difference.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> {
3636
ValuePairs::Types(sub_expected_found),
3737
ValuePairs::Types(sup_expected_found),
3838
CompareImplMethodObligation { trait_item_def_id, .. },
39-
) = (&sub_trace.values, &sup_trace.values, &sub_trace.cause.code)
39+
) = (&sub_trace.values, &sup_trace.values, sub_trace.cause.code())
4040
{
4141
if sup_expected_found == sub_expected_found {
4242
self.emit_err(

compiler/rustc_infer/src/infer/error_reporting/note.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -359,13 +359,13 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
359359
match placeholder_origin {
360360
infer::Subtype(box ref trace)
361361
if matches!(
362-
&trace.cause.code.peel_derives(),
362+
&trace.cause.code().peel_derives(),
363363
ObligationCauseCode::BindingObligation(..)
364364
) =>
365365
{
366366
// Hack to get around the borrow checker because trace.cause has an `Rc`.
367367
if let ObligationCauseCode::BindingObligation(_, span) =
368-
&trace.cause.code.peel_derives()
368+
&trace.cause.code().peel_derives()
369369
{
370370
let span = *span;
371371
let mut err = self.report_concrete_failure(placeholder_origin, sub, sup);

compiler/rustc_infer/src/infer/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1824,7 +1824,7 @@ impl<'tcx> SubregionOrigin<'tcx> {
18241824
where
18251825
F: FnOnce() -> Self,
18261826
{
1827-
match cause.code {
1827+
match *cause.code() {
18281828
traits::ObligationCauseCode::ReferenceOutlivesReferent(ref_type) => {
18291829
SubregionOrigin::ReferenceOutlivesReferent(ref_type, cause.span)
18301830
}

compiler/rustc_infer/src/infer/outlives/obligations.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ impl<'cx, 'tcx> InferCtxt<'cx, 'tcx> {
102102
infer::RelateParamBound(
103103
cause.span,
104104
sup_type,
105-
match cause.code.peel_derives() {
105+
match cause.code().peel_derives() {
106106
ObligationCauseCode::BindingObligation(_, span) => Some(*span),
107107
_ => None,
108108
},

compiler/rustc_infer/src/traits/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ impl TraitObligation<'_> {
8181

8282
// `PredicateObligation` is used a lot. Make sure it doesn't unintentionally get bigger.
8383
#[cfg(all(target_arch = "x86_64", target_pointer_width = "64"))]
84-
static_assert_size!(PredicateObligation<'_>, 32);
84+
static_assert_size!(PredicateObligation<'_>, 48);
8585

8686
pub type PredicateObligations<'tcx> = Vec<PredicateObligation<'tcx>>;
8787

compiler/rustc_middle/src/traits/mod.rs

+40-42
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@ use rustc_span::{Span, DUMMY_SP};
2323
use smallvec::SmallVec;
2424

2525
use std::borrow::Cow;
26-
use std::fmt;
2726
use std::hash::{Hash, Hasher};
28-
use std::ops::Deref;
2927

3028
pub use self::select::{EvaluationCache, EvaluationResult, OverflowError, SelectionCache};
3129

@@ -80,38 +78,14 @@ pub enum Reveal {
8078

8179
/// The reason why we incurred this obligation; used for error reporting.
8280
///
83-
/// As the happy path does not care about this struct, storing this on the heap
84-
/// ends up increasing performance.
81+
/// Non-misc `ObligationCauseCode`s are stored on the heap. This gives the
82+
/// best trade-off between keeping the type small (which makes copies cheaper)
83+
/// while not doing too many heap allocations.
8584
///
8685
/// We do not want to intern this as there are a lot of obligation causes which
8786
/// only live for a short period of time.
88-
#[derive(Clone, PartialEq, Eq, Hash, Lift)]
89-
pub struct ObligationCause<'tcx> {
90-
/// `None` for `ObligationCause::dummy`, `Some` otherwise.
91-
data: Option<Lrc<ObligationCauseData<'tcx>>>,
92-
}
93-
94-
const DUMMY_OBLIGATION_CAUSE_DATA: ObligationCauseData<'static> =
95-
ObligationCauseData { span: DUMMY_SP, body_id: hir::CRATE_HIR_ID, code: MiscObligation };
96-
97-
// Correctly format `ObligationCause::dummy`.
98-
impl<'tcx> fmt::Debug for ObligationCause<'tcx> {
99-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100-
ObligationCauseData::fmt(self, f)
101-
}
102-
}
103-
104-
impl<'tcx> Deref for ObligationCause<'tcx> {
105-
type Target = ObligationCauseData<'tcx>;
106-
107-
#[inline(always)]
108-
fn deref(&self) -> &Self::Target {
109-
self.data.as_deref().unwrap_or(&DUMMY_OBLIGATION_CAUSE_DATA)
110-
}
111-
}
112-
11387
#[derive(Clone, Debug, PartialEq, Eq, Lift)]
114-
pub struct ObligationCauseData<'tcx> {
88+
pub struct ObligationCause<'tcx> {
11589
pub span: Span,
11690

11791
/// The ID of the fn body that triggered this obligation. This is
@@ -122,46 +96,58 @@ pub struct ObligationCauseData<'tcx> {
12296
/// information.
12397
pub body_id: hir::HirId,
12498

125-
pub code: ObligationCauseCode<'tcx>,
99+
/// `None` for `MISC_OBLIGATION_CAUSE_CODE` (a common case, occurs ~60% of
100+
/// the time). `Some` otherwise.
101+
code: Option<Lrc<ObligationCauseCode<'tcx>>>,
126102
}
127103

128-
impl Hash for ObligationCauseData<'_> {
104+
// This custom hash function speeds up hashing for `Obligation` deduplication
105+
// greatly by skipping the `code` field, which can be large and complex. That
106+
// shouldn't affect hash quality much since there are several other fields in
107+
// `Obligation` which should be unique enough, especially the predicate itself
108+
// which is hashed as an interned pointer. See #90996.
109+
impl Hash for ObligationCause<'_> {
129110
fn hash<H: Hasher>(&self, state: &mut H) {
130111
self.body_id.hash(state);
131112
self.span.hash(state);
132-
std::mem::discriminant(&self.code).hash(state);
133113
}
134114
}
135115

116+
const MISC_OBLIGATION_CAUSE_CODE: ObligationCauseCode<'static> = MiscObligation;
117+
136118
impl<'tcx> ObligationCause<'tcx> {
137119
#[inline]
138120
pub fn new(
139121
span: Span,
140122
body_id: hir::HirId,
141123
code: ObligationCauseCode<'tcx>,
142124
) -> ObligationCause<'tcx> {
143-
ObligationCause { data: Some(Lrc::new(ObligationCauseData { span, body_id, code })) }
125+
ObligationCause {
126+
span,
127+
body_id,
128+
code: if code == MISC_OBLIGATION_CAUSE_CODE { None } else { Some(Lrc::new(code)) },
129+
}
144130
}
145131

146132
pub fn misc(span: Span, body_id: hir::HirId) -> ObligationCause<'tcx> {
147133
ObligationCause::new(span, body_id, MiscObligation)
148134
}
149135

150-
pub fn dummy_with_span(span: Span) -> ObligationCause<'tcx> {
151-
ObligationCause::new(span, hir::CRATE_HIR_ID, MiscObligation)
152-
}
153-
154136
#[inline(always)]
155137
pub fn dummy() -> ObligationCause<'tcx> {
156-
ObligationCause { data: None }
138+
ObligationCause { span: DUMMY_SP, body_id: hir::CRATE_HIR_ID, code: None }
139+
}
140+
141+
pub fn dummy_with_span(span: Span) -> ObligationCause<'tcx> {
142+
ObligationCause { span, body_id: hir::CRATE_HIR_ID, code: None }
157143
}
158144

159-
pub fn make_mut(&mut self) -> &mut ObligationCauseData<'tcx> {
160-
Lrc::make_mut(self.data.get_or_insert_with(|| Lrc::new(DUMMY_OBLIGATION_CAUSE_DATA)))
145+
pub fn make_mut_code(&mut self) -> &mut ObligationCauseCode<'tcx> {
146+
Lrc::make_mut(self.code.get_or_insert_with(|| Lrc::new(MISC_OBLIGATION_CAUSE_CODE)))
161147
}
162148

163149
pub fn span(&self, tcx: TyCtxt<'tcx>) -> Span {
164-
match self.code {
150+
match *self.code() {
165151
ObligationCauseCode::CompareImplMethodObligation { .. }
166152
| ObligationCauseCode::MainFunctionType
167153
| ObligationCauseCode::StartFunctionType => {
@@ -174,6 +160,18 @@ impl<'tcx> ObligationCause<'tcx> {
174160
_ => self.span,
175161
}
176162
}
163+
164+
#[inline]
165+
pub fn code(&self) -> &ObligationCauseCode<'tcx> {
166+
self.code.as_deref().unwrap_or(&MISC_OBLIGATION_CAUSE_CODE)
167+
}
168+
169+
pub fn clone_code(&self) -> Lrc<ObligationCauseCode<'tcx>> {
170+
match &self.code {
171+
Some(code) => code.clone(),
172+
None => Lrc::new(MISC_OBLIGATION_CAUSE_CODE),
173+
}
174+
}
177175
}
178176

179177
#[derive(Clone, Debug, PartialEq, Eq, Hash, Lift)]

compiler/rustc_middle/src/ty/error.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ impl<T> Trait<T> for X {
519519
proj_ty,
520520
values,
521521
body_owner_def_id,
522-
&cause.code,
522+
cause.code(),
523523
);
524524
}
525525
(_, ty::Projection(proj_ty)) => {

0 commit comments

Comments
 (0)