Skip to content

Commit 349b3b3

Browse files
committed
Auto merge of #79209 - spastorino:trait-inheritance-self, r=nikomatsakis
Allow Trait inheritance with cycles on associated types Fixes #35237 r? `@nikomatsakis` cc `@estebank`
2 parents b776d1c + ada7c1f commit 349b3b3

28 files changed

+498
-151
lines changed

compiler/rustc_infer/src/traits/util.rs

+33-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
use smallvec::smallvec;
22

33
use crate::traits::{Obligation, ObligationCause, PredicateObligation};
4-
use rustc_data_structures::fx::FxHashSet;
4+
use rustc_data_structures::fx::{FxHashSet, FxIndexSet};
55
use rustc_middle::ty::outlives::Component;
66
use rustc_middle::ty::{self, ToPredicate, TyCtxt, WithConstness};
7+
use rustc_span::symbol::Ident;
78

89
pub fn anonymize_predicate<'tcx>(
910
tcx: TyCtxt<'tcx>,
@@ -287,6 +288,37 @@ pub fn transitive_bounds<'tcx>(
287288
elaborate_trait_refs(tcx, bounds).filter_to_traits()
288289
}
289290

291+
/// A specialized variant of `elaborate_trait_refs` that only elaborates trait references that may
292+
/// define the given associated type `assoc_name`. It uses the
293+
/// `super_predicates_that_define_assoc_type` query to avoid enumerating super-predicates that
294+
/// aren't related to `assoc_item`. This is used when resolving types like `Self::Item` or
295+
/// `T::Item` and helps to avoid cycle errors (see e.g. #35237).
296+
pub fn transitive_bounds_that_define_assoc_type<'tcx>(
297+
tcx: TyCtxt<'tcx>,
298+
bounds: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
299+
assoc_name: Ident,
300+
) -> FxIndexSet<ty::PolyTraitRef<'tcx>> {
301+
let mut stack: Vec<_> = bounds.collect();
302+
let mut trait_refs = FxIndexSet::default();
303+
304+
while let Some(trait_ref) = stack.pop() {
305+
if trait_refs.insert(trait_ref) {
306+
let super_predicates =
307+
tcx.super_predicates_that_define_assoc_type((trait_ref.def_id(), Some(assoc_name)));
308+
for (super_predicate, _) in super_predicates.predicates {
309+
let bound_predicate = super_predicate.bound_atom();
310+
let subst_predicate = super_predicate
311+
.subst_supertrait(tcx, &bound_predicate.rebind(trait_ref.skip_binder()));
312+
if let Some(binder) = subst_predicate.to_opt_poly_trait_ref() {
313+
stack.push(binder.value);
314+
}
315+
}
316+
}
317+
}
318+
319+
trait_refs
320+
}
321+
290322
///////////////////////////////////////////////////////////////////////////
291323
// Other
292324
///////////////////////////////////////////////////////////////////////////

compiler/rustc_middle/src/query/mod.rs

+13-2
Original file line numberDiff line numberDiff line change
@@ -433,12 +433,23 @@ rustc_queries! {
433433
/// full predicates are available (note that supertraits have
434434
/// additional acyclicity requirements).
435435
query super_predicates_of(key: DefId) -> ty::GenericPredicates<'tcx> {
436-
desc { |tcx| "computing the supertraits of `{}`", tcx.def_path_str(key) }
436+
desc { |tcx| "computing the super predicates of `{}`", tcx.def_path_str(key) }
437+
}
438+
439+
/// The `Option<Ident>` is the name of an associated type. If it is `None`, then this query
440+
/// returns the full set of predicates. If `Some<Ident>`, then the query returns only the
441+
/// subset of super-predicates that reference traits that define the given associated type.
442+
/// This is used to avoid cycles in resolving types like `T::Item`.
443+
query super_predicates_that_define_assoc_type(key: (DefId, Option<rustc_span::symbol::Ident>)) -> ty::GenericPredicates<'tcx> {
444+
desc { |tcx| "computing the super traits of `{}`{}",
445+
tcx.def_path_str(key.0),
446+
if let Some(assoc_name) = key.1 { format!(" with associated type name `{}`", assoc_name) } else { "".to_string() },
447+
}
437448
}
438449

439450
/// To avoid cycles within the predicates of a single item we compute
440451
/// per-type-parameter predicates for resolving `T::AssocTy`.
441-
query type_param_predicates(key: (DefId, LocalDefId)) -> ty::GenericPredicates<'tcx> {
452+
query type_param_predicates(key: (DefId, LocalDefId, rustc_span::symbol::Ident)) -> ty::GenericPredicates<'tcx> {
442453
desc { |tcx| "computing the bounds for type parameter `{}`", {
443454
let id = tcx.hir().local_def_id_to_hir_id(key.1);
444455
tcx.hir().ty_param_name(id)

compiler/rustc_middle/src/ty/context.rs

+37-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ use rustc_session::config::{BorrowckMode, CrateType, OutputFilenames};
5151
use rustc_session::lint::{Level, Lint};
5252
use rustc_session::Session;
5353
use rustc_span::source_map::MultiSpan;
54-
use rustc_span::symbol::{kw, sym, Symbol};
54+
use rustc_span::symbol::{kw, sym, Ident, Symbol};
5555
use rustc_span::{Span, DUMMY_SP};
5656
use rustc_target::abi::{Layout, TargetDataLayout, VariantIdx};
5757
use rustc_target::spec::abi;
@@ -2085,6 +2085,42 @@ impl<'tcx> TyCtxt<'tcx> {
20852085
self.mk_fn_ptr(sig.map_bound(|sig| ty::FnSig { unsafety: hir::Unsafety::Unsafe, ..sig }))
20862086
}
20872087

2088+
/// Given the def_id of a Trait `trait_def_id` and the name of an associated item `assoc_name`
2089+
/// returns true if the `trait_def_id` defines an associated item of name `assoc_name`.
2090+
pub fn trait_may_define_assoc_type(self, trait_def_id: DefId, assoc_name: Ident) -> bool {
2091+
self.super_traits_of(trait_def_id).any(|trait_did| {
2092+
self.associated_items(trait_did)
2093+
.find_by_name_and_kind(self, assoc_name, ty::AssocKind::Type, trait_did)
2094+
.is_some()
2095+
})
2096+
}
2097+
2098+
/// Computes the def-ids of the transitive super-traits of `trait_def_id`. This (intentionally)
2099+
/// does not compute the full elaborated super-predicates but just the set of def-ids. It is used
2100+
/// to identify which traits may define a given associated type to help avoid cycle errors.
2101+
/// Returns a `DefId` iterator.
2102+
fn super_traits_of(self, trait_def_id: DefId) -> impl Iterator<Item = DefId> + 'tcx {
2103+
let mut set = FxHashSet::default();
2104+
let mut stack = vec![trait_def_id];
2105+
2106+
set.insert(trait_def_id);
2107+
2108+
iter::from_fn(move || -> Option<DefId> {
2109+
let trait_did = stack.pop()?;
2110+
let generic_predicates = self.super_predicates_of(trait_did);
2111+
2112+
for (predicate, _) in generic_predicates.predicates {
2113+
if let ty::PredicateAtom::Trait(data, _) = predicate.skip_binders() {
2114+
if set.insert(data.def_id()) {
2115+
stack.push(data.def_id());
2116+
}
2117+
}
2118+
}
2119+
2120+
Some(trait_did)
2121+
})
2122+
}
2123+
20882124
/// Given a closure signature, returns an equivalent fn signature. Detuples
20892125
/// and so forth -- so e.g., if we have a sig with `Fn<(u32, i32)>` then
20902126
/// you would get a `fn(u32, i32)`.

compiler/rustc_middle/src/ty/query/keys.rs

+23-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::ty::subst::{GenericArg, SubstsRef};
77
use crate::ty::{self, Ty, TyCtxt};
88
use rustc_hir::def_id::{CrateNum, DefId, LocalDefId, LOCAL_CRATE};
99
use rustc_query_system::query::DefaultCacheSelector;
10-
use rustc_span::symbol::Symbol;
10+
use rustc_span::symbol::{Ident, Symbol};
1111
use rustc_span::{Span, DUMMY_SP};
1212

1313
/// The `Key` trait controls what types can legally be used as the key
@@ -149,6 +149,28 @@ impl Key for (LocalDefId, DefId) {
149149
}
150150
}
151151

152+
impl Key for (DefId, Option<Ident>) {
153+
type CacheSelector = DefaultCacheSelector;
154+
155+
fn query_crate(&self) -> CrateNum {
156+
self.0.krate
157+
}
158+
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
159+
tcx.def_span(self.0)
160+
}
161+
}
162+
163+
impl Key for (DefId, LocalDefId, Ident) {
164+
type CacheSelector = DefaultCacheSelector;
165+
166+
fn query_crate(&self) -> CrateNum {
167+
self.0.krate
168+
}
169+
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
170+
self.1.default_span(tcx)
171+
}
172+
}
173+
152174
impl Key for (CrateNum, DefId) {
153175
type CacheSelector = DefaultCacheSelector;
154176

compiler/rustc_trait_selection/src/traits/mod.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ pub use self::util::{
6565
get_vtable_index_of_object_method, impl_item_is_final, predicate_for_trait_def, upcast_choices,
6666
};
6767
pub use self::util::{
68-
supertrait_def_ids, supertraits, transitive_bounds, SupertraitDefIds, Supertraits,
68+
supertrait_def_ids, supertraits, transitive_bounds, transitive_bounds_that_define_assoc_type,
69+
SupertraitDefIds, Supertraits,
6970
};
7071

7172
pub use self::chalk_fulfill::FulfillmentContext as ChalkFulfillmentContext;

compiler/rustc_typeck/src/astconv/mod.rs

+57-11
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,10 @@ pub trait AstConv<'tcx> {
4949

5050
fn default_constness_for_trait_bounds(&self) -> Constness;
5151

52-
/// Returns predicates in scope of the form `X: Foo`, where `X` is
53-
/// a type parameter `X` with the given id `def_id`. This is a
54-
/// subset of the full set of predicates.
52+
/// Returns predicates in scope of the form `X: Foo<T>`, where `X`
53+
/// is a type parameter `X` with the given id `def_id` and T
54+
/// matches `assoc_name`. This is a subset of the full set of
55+
/// predicates.
5556
///
5657
/// This is used for one specific purpose: resolving "short-hand"
5758
/// associated type references like `T::Item`. In principle, we
@@ -60,7 +61,12 @@ pub trait AstConv<'tcx> {
6061
/// but this can lead to cycle errors. The problem is that we have
6162
/// to do this resolution *in order to create the predicates in
6263
/// the first place*. Hence, we have this "special pass".
63-
fn get_type_parameter_bounds(&self, span: Span, def_id: DefId) -> ty::GenericPredicates<'tcx>;
64+
fn get_type_parameter_bounds(
65+
&self,
66+
span: Span,
67+
def_id: DefId,
68+
assoc_name: Ident,
69+
) -> ty::GenericPredicates<'tcx>;
6470

6571
/// Returns the lifetime to use when a lifetime is omitted (and not elided).
6672
fn re_infer(&self, param: Option<&ty::GenericParamDef>, span: Span)
@@ -762,7 +768,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
762768
}
763769

764770
// Returns `true` if a bounds list includes `?Sized`.
765-
pub fn is_unsized(&self, ast_bounds: &[hir::GenericBound<'_>], span: Span) -> bool {
771+
pub fn is_unsized(&self, ast_bounds: &[&hir::GenericBound<'_>], span: Span) -> bool {
766772
let tcx = self.tcx();
767773

768774
// Try to find an unbound in bounds.
@@ -820,7 +826,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
820826
fn add_bounds(
821827
&self,
822828
param_ty: Ty<'tcx>,
823-
ast_bounds: &[hir::GenericBound<'_>],
829+
ast_bounds: &[&hir::GenericBound<'_>],
824830
bounds: &mut Bounds<'tcx>,
825831
) {
826832
let constness = self.default_constness_for_trait_bounds();
@@ -835,7 +841,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
835841
hir::GenericBound::Trait(_, hir::TraitBoundModifier::Maybe) => {}
836842
hir::GenericBound::LangItemTrait(lang_item, span, hir_id, args) => self
837843
.instantiate_lang_item_trait_ref(
838-
lang_item, span, hir_id, args, param_ty, bounds,
844+
*lang_item, *span, *hir_id, args, param_ty, bounds,
839845
),
840846
hir::GenericBound::Outlives(ref l) => {
841847
bounds.region_bounds.push((self.ast_region_to_region(l, None), l.span))
@@ -866,6 +872,42 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
866872
ast_bounds: &[hir::GenericBound<'_>],
867873
sized_by_default: SizedByDefault,
868874
span: Span,
875+
) -> Bounds<'tcx> {
876+
let ast_bounds: Vec<_> = ast_bounds.iter().collect();
877+
self.compute_bounds_inner(param_ty, &ast_bounds, sized_by_default, span)
878+
}
879+
880+
/// Convert the bounds in `ast_bounds` that refer to traits which define an associated type
881+
/// named `assoc_name` into ty::Bounds. Ignore the rest.
882+
pub fn compute_bounds_that_match_assoc_type(
883+
&self,
884+
param_ty: Ty<'tcx>,
885+
ast_bounds: &[hir::GenericBound<'_>],
886+
sized_by_default: SizedByDefault,
887+
span: Span,
888+
assoc_name: Ident,
889+
) -> Bounds<'tcx> {
890+
let mut result = Vec::new();
891+
892+
for ast_bound in ast_bounds {
893+
if let Some(trait_ref) = ast_bound.trait_ref() {
894+
if let Some(trait_did) = trait_ref.trait_def_id() {
895+
if self.tcx().trait_may_define_assoc_type(trait_did, assoc_name) {
896+
result.push(ast_bound);
897+
}
898+
}
899+
}
900+
}
901+
902+
self.compute_bounds_inner(param_ty, &result, sized_by_default, span)
903+
}
904+
905+
fn compute_bounds_inner(
906+
&self,
907+
param_ty: Ty<'tcx>,
908+
ast_bounds: &[&hir::GenericBound<'_>],
909+
sized_by_default: SizedByDefault,
910+
span: Span,
869911
) -> Bounds<'tcx> {
870912
let mut bounds = Bounds::default();
871913

@@ -1035,7 +1077,8 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
10351077
// Calling `skip_binder` is okay, because `add_bounds` expects the `param_ty`
10361078
// parameter to have a skipped binder.
10371079
let param_ty = tcx.mk_projection(assoc_ty.def_id, candidate.skip_binder().substs);
1038-
self.add_bounds(param_ty, ast_bounds, bounds);
1080+
let ast_bounds: Vec<_> = ast_bounds.iter().collect();
1081+
self.add_bounds(param_ty, &ast_bounds, bounds);
10391082
}
10401083
}
10411084
Ok(())
@@ -1352,21 +1395,24 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
13521395
ty_param_def_id, assoc_name, span,
13531396
);
13541397

1355-
let predicates =
1356-
&self.get_type_parameter_bounds(span, ty_param_def_id.to_def_id()).predicates;
1398+
let predicates = &self
1399+
.get_type_parameter_bounds(span, ty_param_def_id.to_def_id(), assoc_name)
1400+
.predicates;
13571401

13581402
debug!("find_bound_for_assoc_item: predicates={:#?}", predicates);
13591403

13601404
let param_hir_id = tcx.hir().local_def_id_to_hir_id(ty_param_def_id);
13611405
let param_name = tcx.hir().ty_param_name(param_hir_id);
13621406
self.one_bound_for_assoc_type(
13631407
|| {
1364-
traits::transitive_bounds(
1408+
traits::transitive_bounds_that_define_assoc_type(
13651409
tcx,
13661410
predicates.iter().filter_map(|(p, _)| {
13671411
p.to_opt_poly_trait_ref().map(|trait_ref| trait_ref.value)
13681412
}),
1413+
assoc_name,
13691414
)
1415+
.into_iter()
13701416
},
13711417
|| param_name.to_string(),
13721418
assoc_name,

compiler/rustc_typeck/src/check/fn_ctxt/mod.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use rustc_middle::ty::fold::TypeFoldable;
2020
use rustc_middle::ty::subst::GenericArgKind;
2121
use rustc_middle::ty::{self, Const, Ty, TyCtxt};
2222
use rustc_session::Session;
23+
use rustc_span::symbol::Ident;
2324
use rustc_span::{self, Span};
2425
use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode};
2526

@@ -183,7 +184,12 @@ impl<'a, 'tcx> AstConv<'tcx> for FnCtxt<'a, 'tcx> {
183184
}
184185
}
185186

186-
fn get_type_parameter_bounds(&self, _: Span, def_id: DefId) -> ty::GenericPredicates<'tcx> {
187+
fn get_type_parameter_bounds(
188+
&self,
189+
_: Span,
190+
def_id: DefId,
191+
_: Ident,
192+
) -> ty::GenericPredicates<'tcx> {
187193
let tcx = self.tcx;
188194
let hir_id = tcx.hir().local_def_id_to_hir_id(def_id.expect_local());
189195
let item_id = tcx.hir().ty_param_owner(hir_id);

0 commit comments

Comments
 (0)