Skip to content

Commit 2ca4964

Browse files
committed
Allow to self reference associated types in where clauses
1 parent 24dcf6f commit 2ca4964

File tree

8 files changed

+204
-63
lines changed

8 files changed

+204
-63
lines changed

compiler/rustc_infer/src/traits/util.rs

+35
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use crate::traits::{Obligation, ObligationCause, PredicateObligation};
44
use rustc_data_structures::fx::FxHashSet;
55
use rustc_middle::ty::outlives::Component;
66
use rustc_middle::ty::{self, ToPredicate, TyCtxt, WithConstness};
7+
use rustc_span::symbol::Ident;
78

89
pub fn anonymize_predicate<'tcx>(
910
tcx: TyCtxt<'tcx>,
@@ -89,6 +90,32 @@ pub fn elaborate_trait_refs<'tcx>(
8990
elaborate_predicates(tcx, predicates)
9091
}
9192

93+
pub fn elaborate_trait_refs_that_define_assoc_type<'tcx>(
94+
tcx: TyCtxt<'tcx>,
95+
trait_refs: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
96+
assoc_name: Ident,
97+
) -> FxHashSet<ty::PolyTraitRef<'tcx>> {
98+
let mut stack: Vec<_> = trait_refs.collect();
99+
let mut trait_refs = FxHashSet::default();
100+
101+
while let Some(trait_ref) = stack.pop() {
102+
if trait_refs.insert(trait_ref) {
103+
let super_predicates =
104+
tcx.super_predicates_that_define_assoc_type((trait_ref.def_id(), Some(assoc_name)));
105+
for (super_predicate, _) in super_predicates.predicates {
106+
let bound_predicate = super_predicate.bound_atom();
107+
let subst_predicate = super_predicate
108+
.subst_supertrait(tcx, &bound_predicate.rebind(trait_ref.skip_binder()));
109+
if let Some(binder) = subst_predicate.to_opt_poly_trait_ref() {
110+
stack.push(binder.value);
111+
}
112+
}
113+
}
114+
}
115+
116+
trait_refs
117+
}
118+
92119
pub fn elaborate_predicates<'tcx>(
93120
tcx: TyCtxt<'tcx>,
94121
predicates: impl Iterator<Item = ty::Predicate<'tcx>>,
@@ -287,6 +314,14 @@ pub fn transitive_bounds<'tcx>(
287314
elaborate_trait_refs(tcx, bounds).filter_to_traits()
288315
}
289316

317+
pub fn transitive_bounds_that_define_assoc_type<'tcx>(
318+
tcx: TyCtxt<'tcx>,
319+
bounds: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
320+
assoc_name: Ident,
321+
) -> FxHashSet<ty::PolyTraitRef<'tcx>> {
322+
elaborate_trait_refs_that_define_assoc_type(tcx, bounds, assoc_name)
323+
}
324+
290325
///////////////////////////////////////////////////////////////////////////
291326
// Other
292327
///////////////////////////////////////////////////////////////////////////

compiler/rustc_middle/src/query/mod.rs

+10
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,16 @@ rustc_queries! {
436436
desc { |tcx| "computing the supertraits of `{}`", tcx.def_path_str(key) }
437437
}
438438

439+
/// Maps from the `DefId` of a trait to the list of
440+
/// super-predicates. This is a subset of the full list of
441+
/// predicates. We store these in a separate map because we must
442+
/// evaluate them even during type conversion, often before the
443+
/// full predicates are available (note that supertraits have
444+
/// additional acyclicity requirements).
445+
query super_predicates_that_define_assoc_type(key: (DefId, Option<rustc_span::symbol::Ident>)) -> ty::GenericPredicates<'tcx> {
446+
desc { |tcx| "computing the supertraits of `{}`", tcx.def_path_str(key.0) }
447+
}
448+
439449
/// To avoid cycles within the predicates of a single item we compute
440450
/// per-type-parameter predicates for resolving `T::AssocTy`.
441451
query type_param_predicates(key: (DefId, LocalDefId, rustc_span::symbol::Ident)) -> ty::GenericPredicates<'tcx> {

compiler/rustc_middle/src/ty/query/keys.rs

+11
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,17 @@ impl Key for (LocalDefId, DefId) {
149149
}
150150
}
151151

152+
impl Key for (DefId, Option<Ident>) {
153+
type CacheSelector = DefaultCacheSelector;
154+
155+
fn query_crate(&self) -> CrateNum {
156+
self.0.krate
157+
}
158+
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
159+
tcx.def_span(self.0)
160+
}
161+
}
162+
152163
impl Key for (DefId, LocalDefId, Ident) {
153164
type CacheSelector = DefaultCacheSelector;
154165

compiler/rustc_trait_selection/src/traits/mod.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ pub use self::util::{
6565
get_vtable_index_of_object_method, impl_item_is_final, predicate_for_trait_def, upcast_choices,
6666
};
6767
pub use self::util::{
68-
supertrait_def_ids, supertraits, transitive_bounds, SupertraitDefIds, Supertraits,
68+
supertrait_def_ids, supertraits, transitive_bounds, transitive_bounds_that_define_assoc_type,
69+
SupertraitDefIds, Supertraits,
6970
};
7071

7172
pub use self::chalk_fulfill::FulfillmentContext as ChalkFulfillmentContext;

compiler/rustc_typeck/src/astconv/mod.rs

+43-6
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ mod errors;
66
mod generics;
77

88
use crate::bounds::Bounds;
9+
use crate::collect::super_traits_of;
910
use crate::collect::PlaceholderHirTyCollector;
1011
use crate::errors::{
1112
AmbiguousLifetimeBound, MultipleRelaxedDefaultBounds, TraitObjectDeclaredWithNoTraits,
@@ -768,7 +769,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
768769
}
769770

770771
// Returns `true` if a bounds list includes `?Sized`.
771-
pub fn is_unsized(&self, ast_bounds: &[hir::GenericBound<'_>], span: Span) -> bool {
772+
pub fn is_unsized(&self, ast_bounds: &[&hir::GenericBound<'_>], span: Span) -> bool {
772773
let tcx = self.tcx();
773774

774775
// Try to find an unbound in bounds.
@@ -826,7 +827,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
826827
fn add_bounds(
827828
&self,
828829
param_ty: Ty<'tcx>,
829-
ast_bounds: &[hir::GenericBound<'_>],
830+
ast_bounds: &[&hir::GenericBound<'_>],
830831
bounds: &mut Bounds<'tcx>,
831832
) {
832833
let mut trait_bounds = Vec::new();
@@ -844,7 +845,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
844845
hir::GenericBound::Trait(_, hir::TraitBoundModifier::Maybe) => {}
845846
hir::GenericBound::LangItemTrait(lang_item, span, hir_id, args) => self
846847
.instantiate_lang_item_trait_ref(
847-
lang_item, span, hir_id, args, param_ty, bounds,
848+
*lang_item, *span, *hir_id, args, param_ty, bounds,
848849
),
849850
hir::GenericBound::Outlives(ref l) => region_bounds.push(l),
850851
}
@@ -878,7 +879,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
878879
pub fn compute_bounds(
879880
&self,
880881
param_ty: Ty<'tcx>,
881-
ast_bounds: &[hir::GenericBound<'_>],
882+
ast_bounds: &[&hir::GenericBound<'_>],
882883
sized_by_default: SizedByDefault,
883884
span: Span,
884885
) -> Bounds<'tcx> {
@@ -896,6 +897,39 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
896897
bounds
897898
}
898899

900+
pub fn compute_bounds_that_match_assoc_type(
901+
&self,
902+
param_ty: Ty<'tcx>,
903+
ast_bounds: &[hir::GenericBound<'_>],
904+
sized_by_default: SizedByDefault,
905+
span: Span,
906+
assoc_name: Ident,
907+
) -> Bounds<'tcx> {
908+
let mut result = Vec::new();
909+
910+
for ast_bound in ast_bounds {
911+
if let Some(trait_ref) = ast_bound.trait_ref() {
912+
if let Some(trait_did) = trait_ref.trait_def_id() {
913+
if super_traits_of(self.tcx(), trait_did).any(|trait_did| {
914+
self.tcx()
915+
.associated_items(trait_did)
916+
.find_by_name_and_kind(
917+
self.tcx(),
918+
assoc_name,
919+
ty::AssocKind::Type,
920+
trait_did,
921+
)
922+
.is_some()
923+
}) {
924+
result.push(ast_bound);
925+
}
926+
}
927+
}
928+
}
929+
930+
self.compute_bounds(param_ty, &result, sized_by_default, span)
931+
}
932+
899933
/// Given an HIR binding like `Item = Foo` or `Item: Foo`, pushes the corresponding predicates
900934
/// onto `bounds`.
901935
///
@@ -1050,7 +1084,8 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
10501084
// Calling `skip_binder` is okay, because `add_bounds` expects the `param_ty`
10511085
// parameter to have a skipped binder.
10521086
let param_ty = tcx.mk_projection(assoc_ty.def_id, candidate.skip_binder().substs);
1053-
self.add_bounds(param_ty, ast_bounds, bounds);
1087+
let ast_bounds: Vec<_> = ast_bounds.iter().collect();
1088+
self.add_bounds(param_ty, &ast_bounds, bounds);
10541089
}
10551090
}
10561091
Ok(())
@@ -1377,12 +1412,14 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
13771412
let param_name = tcx.hir().ty_param_name(param_hir_id);
13781413
self.one_bound_for_assoc_type(
13791414
|| {
1380-
traits::transitive_bounds(
1415+
traits::transitive_bounds_that_define_assoc_type(
13811416
tcx,
13821417
predicates.iter().filter_map(|(p, _)| {
13831418
p.to_opt_poly_trait_ref().map(|trait_ref| trait_ref.value)
13841419
}),
1420+
assoc_name,
13851421
)
1422+
.into_iter()
13861423
},
13871424
|| param_name.to_string(),
13881425
assoc_name,

compiler/rustc_typeck/src/collect.rs

+87-54
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
// ignore-tidy-filelength
12
//! "Collection" is the process of determining the type and other external
23
//! details of each item in Rust. Collection is specifically concerned
34
//! with *inter-procedural* things -- for example, for a function
@@ -79,6 +80,7 @@ pub fn provide(providers: &mut Providers) {
7980
projection_ty_from_predicates,
8081
explicit_predicates_of,
8182
super_predicates_of,
83+
super_predicates_that_define_assoc_type,
8284
trait_explicit_predicates_and_bounds,
8385
type_param_predicates,
8486
trait_def,
@@ -651,17 +653,10 @@ impl ItemCtxt<'tcx> {
651653
hir::GenericBound::Trait(poly_trait_ref, _) => {
652654
let trait_ref = &poly_trait_ref.trait_ref;
653655
let trait_did = trait_ref.trait_def_id().unwrap();
654-
let traits_did = super_traits_of(self.tcx, trait_did);
655-
656-
traits_did.iter().any(|trait_did| {
656+
super_traits_of(self.tcx, trait_did).any(|trait_did| {
657657
self.tcx
658-
.associated_items(*trait_did)
659-
.find_by_name_and_kind(
660-
self.tcx,
661-
assoc_name,
662-
ty::AssocKind::Type,
663-
*trait_did,
664-
)
658+
.associated_items(trait_did)
659+
.find_by_name_and_kind(self.tcx, assoc_name, ty::AssocKind::Type, trait_did)
665660
.is_some()
666661
})
667662
}
@@ -1035,55 +1030,91 @@ fn adt_def(tcx: TyCtxt<'_>, def_id: DefId) -> &ty::AdtDef {
10351030
/// the transitive super-predicates are converted.
10361031
fn super_predicates_of(tcx: TyCtxt<'_>, trait_def_id: DefId) -> ty::GenericPredicates<'_> {
10371032
debug!("super_predicates(trait_def_id={:?})", trait_def_id);
1038-
let trait_hir_id = tcx.hir().local_def_id_to_hir_id(trait_def_id.expect_local());
1033+
tcx.super_predicates_that_define_assoc_type((trait_def_id, None))
1034+
}
10391035

1040-
let item = match tcx.hir().get(trait_hir_id) {
1041-
Node::Item(item) => item,
1042-
_ => bug!("trait_node_id {} is not an item", trait_hir_id),
1043-
};
1036+
/// Ensures that the super-predicates of the trait with a `DefId`
1037+
/// of `trait_def_id` are converted and stored. This also ensures that
1038+
/// the transitive super-predicates are converted.
1039+
fn super_predicates_that_define_assoc_type(
1040+
tcx: TyCtxt<'_>,
1041+
(trait_def_id, assoc_name): (DefId, Option<Ident>),
1042+
) -> ty::GenericPredicates<'_> {
1043+
debug!(
1044+
"super_predicates_that_define_assoc_type(trait_def_id={:?}, assoc_name={:?})",
1045+
trait_def_id, assoc_name
1046+
);
1047+
if trait_def_id.is_local() {
1048+
debug!("super_predicates_that_define_assoc_type: local trait_def_id={:?}", trait_def_id);
1049+
let trait_hir_id = tcx.hir().local_def_id_to_hir_id(trait_def_id.expect_local());
10441050

1045-
let (generics, bounds) = match item.kind {
1046-
hir::ItemKind::Trait(.., ref generics, ref supertraits, _) => (generics, supertraits),
1047-
hir::ItemKind::TraitAlias(ref generics, ref supertraits) => (generics, supertraits),
1048-
_ => span_bug!(item.span, "super_predicates invoked on non-trait"),
1049-
};
1051+
let item = match tcx.hir().get(trait_hir_id) {
1052+
Node::Item(item) => item,
1053+
_ => bug!("trait_node_id {} is not an item", trait_hir_id),
1054+
};
10501055

1051-
let icx = ItemCtxt::new(tcx, trait_def_id);
1052-
1053-
// Convert the bounds that follow the colon, e.g., `Bar + Zed` in `trait Foo: Bar + Zed`.
1054-
let self_param_ty = tcx.types.self_param;
1055-
let superbounds1 =
1056-
AstConv::compute_bounds(&icx, self_param_ty, bounds, SizedByDefault::No, item.span);
1057-
1058-
let superbounds1 = superbounds1.predicates(tcx, self_param_ty);
1059-
1060-
// Convert any explicit superbounds in the where-clause,
1061-
// e.g., `trait Foo where Self: Bar`.
1062-
// In the case of trait aliases, however, we include all bounds in the where-clause,
1063-
// so e.g., `trait Foo = where u32: PartialEq<Self>` would include `u32: PartialEq<Self>`
1064-
// as one of its "superpredicates".
1065-
let is_trait_alias = tcx.is_trait_alias(trait_def_id);
1066-
let superbounds2 = icx.type_parameter_bounds_in_generics(
1067-
generics,
1068-
item.hir_id,
1069-
self_param_ty,
1070-
OnlySelfBounds(!is_trait_alias),
1071-
None,
1072-
);
1056+
let (generics, bounds) = match item.kind {
1057+
hir::ItemKind::Trait(.., ref generics, ref supertraits, _) => (generics, supertraits),
1058+
hir::ItemKind::TraitAlias(ref generics, ref supertraits) => (generics, supertraits),
1059+
_ => span_bug!(item.span, "super_predicates invoked on non-trait"),
1060+
};
10731061

1074-
// Combine the two lists to form the complete set of superbounds:
1075-
let superbounds = &*tcx.arena.alloc_from_iter(superbounds1.into_iter().chain(superbounds2));
1062+
let icx = ItemCtxt::new(tcx, trait_def_id);
10761063

1077-
// Now require that immediate supertraits are converted,
1078-
// which will, in turn, reach indirect supertraits.
1079-
for &(pred, span) in superbounds {
1080-
debug!("superbound: {:?}", pred);
1081-
if let ty::PredicateAtom::Trait(bound, _) = pred.skip_binders() {
1082-
tcx.at(span).super_predicates_of(bound.def_id());
1064+
// Convert the bounds that follow the colon, e.g., `Bar + Zed` in `trait Foo: Bar + Zed`.
1065+
let self_param_ty = tcx.types.self_param;
1066+
let superbounds1 = if let Some(assoc_name) = assoc_name {
1067+
AstConv::compute_bounds_that_match_assoc_type(
1068+
&icx,
1069+
self_param_ty,
1070+
&bounds,
1071+
SizedByDefault::No,
1072+
item.span,
1073+
assoc_name,
1074+
)
1075+
} else {
1076+
let bounds: Vec<_> = bounds.iter().collect();
1077+
AstConv::compute_bounds(&icx, self_param_ty, &bounds, SizedByDefault::No, item.span)
1078+
};
1079+
1080+
let superbounds1 = superbounds1.predicates(tcx, self_param_ty);
1081+
1082+
// Convert any explicit superbounds in the where-clause,
1083+
// e.g., `trait Foo where Self: Bar`.
1084+
// In the case of trait aliases, however, we include all bounds in the where-clause,
1085+
// so e.g., `trait Foo = where u32: PartialEq<Self>` would include `u32: PartialEq<Self>`
1086+
// as one of its "superpredicates".
1087+
let is_trait_alias = tcx.is_trait_alias(trait_def_id);
1088+
let superbounds2 = icx.type_parameter_bounds_in_generics(
1089+
generics,
1090+
item.hir_id,
1091+
self_param_ty,
1092+
OnlySelfBounds(!is_trait_alias),
1093+
assoc_name,
1094+
);
1095+
1096+
// Combine the two lists to form the complete set of superbounds:
1097+
let superbounds = &*tcx.arena.alloc_from_iter(superbounds1.into_iter().chain(superbounds2));
1098+
1099+
// Now require that immediate supertraits are converted,
1100+
// which will, in turn, reach indirect supertraits.
1101+
if assoc_name.is_none() {
1102+
// FIXME: move this into the `super_predicates_of` query
1103+
for &(pred, span) in superbounds {
1104+
debug!("superbound: {:?}", pred);
1105+
if let ty::PredicateAtom::Trait(bound, _) = pred.skip_binders() {
1106+
tcx.at(span).super_predicates_of(bound.def_id());
1107+
}
1108+
}
10831109
}
1084-
}
10851110

1086-
ty::GenericPredicates { parent: None, predicates: superbounds }
1111+
ty::GenericPredicates { parent: None, predicates: superbounds }
1112+
} else {
1113+
// if `assoc_name` is None, then the query should've been redirected to an
1114+
// external provider
1115+
assert!(assoc_name.is_some());
1116+
tcx.super_predicates_of(trait_def_id)
1117+
}
10871118
}
10881119

10891120
pub fn super_traits_of(tcx: TyCtxt<'_>, trait_def_id: DefId) -> impl Iterator<Item = DefId> {
@@ -1123,6 +1154,8 @@ pub fn super_traits_of(tcx: TyCtxt<'_>, trait_def_id: DefId) -> impl Iterator<It
11231154
}
11241155
}
11251156
}
1157+
1158+
set.into_iter()
11261159
}
11271160

11281161
fn trait_def(tcx: TyCtxt<'_>, def_id: DefId) -> ty::TraitDef {
@@ -1976,8 +2009,8 @@ fn gather_explicit_predicates_of(tcx: TyCtxt<'_>, def_id: DefId) -> ty::GenericP
19762009
index += 1;
19772010

19782011
let sized = SizedByDefault::Yes;
1979-
let bounds =
1980-
AstConv::compute_bounds(&icx, param_ty, &param.bounds, sized, param.span);
2012+
let bounds: Vec<_> = param.bounds.iter().collect();
2013+
let bounds = AstConv::compute_bounds(&icx, param_ty, &bounds, sized, param.span);
19812014
predicates.extend(bounds.predicates(tcx, param_ty));
19822015
}
19832016
GenericParamKind::Const { .. } => {

0 commit comments

Comments
 (0)