Skip to content

Commit 90afc07

Browse files
committed
Use a ParamEnvAnd<Predicate> for caching in ObligationForest
Previously, we used a plain `Predicate` to cache results (e.g. successes and failures) in ObligationForest. However, fulfillment depends on the precise `ParamEnv` used, so this is unsound in general. This commit changes the impl of `ForestObligation` for `PendingPredicateObligation` to use `ParamEnvAnd<Predicate>` instead of `Predicate` for the associated type. The associated type and method are renamed from 'predicate' to 'cache_key' to reflect the fact that type is no longer just a predicate.
1 parent d1e594f commit 90afc07

File tree

4 files changed

+26
-18
lines changed

4 files changed

+26
-18
lines changed

src/librustc/traits/fulfill.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,13 @@ use super::{FulfillmentError, FulfillmentErrorCode};
1818
use super::{ObligationCause, PredicateObligation};
1919

2020
impl<'tcx> ForestObligation for PendingPredicateObligation<'tcx> {
21-
type Predicate = ty::Predicate<'tcx>;
21+
/// Note that we include both the `ParamEnv` and the `Predicate`,
22+
/// as the `ParamEnv` can influence whether fulfillment succeeds
23+
/// or fails.
24+
type CacheKey = ty::ParamEnvAnd<'tcx, ty::Predicate<'tcx>>;
2225

23-
fn as_predicate(&self) -> &Self::Predicate {
24-
&self.obligation.predicate
26+
fn as_cache_key(&self) -> Self::CacheKey {
27+
self.obligation.param_env.and(self.obligation.predicate)
2528
}
2629
}
2730

src/librustc_data_structures/obligation_forest/graphviz.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ impl<'a, O: ForestObligation + 'a> dot::Labeller<'a> for &'a ObligationForest<O>
5151

5252
fn node_label(&self, index: &Self::Node) -> dot::LabelText<'_> {
5353
let node = &self.nodes[*index];
54-
let label = format!("{:?} ({:?})", node.obligation.as_predicate(), node.state.get());
54+
let label = format!("{:?} ({:?})", node.obligation.as_cache_key(), node.state.get());
5555

5656
dot::LabelText::LabelStr(label.into())
5757
}

src/librustc_data_structures/obligation_forest/mod.rs

+17-12
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,13 @@ mod graphviz;
8686
mod tests;
8787

8888
pub trait ForestObligation: Clone + Debug {
89-
type Predicate: Clone + hash::Hash + Eq + Debug;
89+
type CacheKey: Clone + hash::Hash + Eq + Debug;
9090

91-
fn as_predicate(&self) -> &Self::Predicate;
91+
/// Converts this `ForestObligation` suitable for use as a cache key.
92+
/// If two distinct `ForestObligations`s return the same cache key,
93+
/// then it must be sound to use the result of processing one obligation
94+
/// (e.g. success for error) for the other obligation
95+
fn as_cache_key(&self) -> Self::CacheKey;
9296
}
9397

9498
pub trait ObligationProcessor {
@@ -138,12 +142,12 @@ pub struct ObligationForest<O: ForestObligation> {
138142
nodes: Vec<Node<O>>,
139143

140144
/// A cache of predicates that have been successfully completed.
141-
done_cache: FxHashSet<O::Predicate>,
145+
done_cache: FxHashSet<O::CacheKey>,
142146

143147
/// A cache of the nodes in `nodes`, indexed by predicate. Unfortunately,
144148
/// its contents are not guaranteed to match those of `nodes`. See the
145149
/// comments in `process_obligation` for details.
146-
active_cache: FxHashMap<O::Predicate, usize>,
150+
active_cache: FxHashMap<O::CacheKey, usize>,
147151

148152
/// A vector reused in compress(), to avoid allocating new vectors.
149153
node_rewrites: RefCell<Vec<usize>>,
@@ -157,7 +161,7 @@ pub struct ObligationForest<O: ForestObligation> {
157161
/// See [this][details] for details.
158162
///
159163
/// [details]: https://github.com/rust-lang/rust/pull/53255#issuecomment-421184780
160-
error_cache: FxHashMap<ObligationTreeId, FxHashSet<O::Predicate>>,
164+
error_cache: FxHashMap<ObligationTreeId, FxHashSet<O::CacheKey>>,
161165
}
162166

163167
#[derive(Debug)]
@@ -305,11 +309,12 @@ impl<O: ForestObligation> ObligationForest<O> {
305309

306310
// Returns Err(()) if we already know this obligation failed.
307311
fn register_obligation_at(&mut self, obligation: O, parent: Option<usize>) -> Result<(), ()> {
308-
if self.done_cache.contains(obligation.as_predicate()) {
312+
if self.done_cache.contains(&obligation.as_cache_key()) {
313+
debug!("register_obligation_at: ignoring already done obligation: {:?}", obligation);
309314
return Ok(());
310315
}
311316

312-
match self.active_cache.entry(obligation.as_predicate().clone()) {
317+
match self.active_cache.entry(obligation.as_cache_key().clone()) {
313318
Entry::Occupied(o) => {
314319
let node = &mut self.nodes[*o.get()];
315320
if let Some(parent_index) = parent {
@@ -333,7 +338,7 @@ impl<O: ForestObligation> ObligationForest<O> {
333338
&& self
334339
.error_cache
335340
.get(&obligation_tree_id)
336-
.map(|errors| errors.contains(obligation.as_predicate()))
341+
.map(|errors| errors.contains(&obligation.as_cache_key()))
337342
.unwrap_or(false);
338343

339344
if already_failed {
@@ -380,7 +385,7 @@ impl<O: ForestObligation> ObligationForest<O> {
380385
self.error_cache
381386
.entry(node.obligation_tree_id)
382387
.or_default()
383-
.insert(node.obligation.as_predicate().clone());
388+
.insert(node.obligation.as_cache_key().clone());
384389
}
385390

386391
/// Performs a pass through the obligation list. This must
@@ -618,11 +623,11 @@ impl<O: ForestObligation> ObligationForest<O> {
618623
// `self.nodes`. See the comment in `process_obligation`
619624
// for more details.
620625
if let Some((predicate, _)) =
621-
self.active_cache.remove_entry(node.obligation.as_predicate())
626+
self.active_cache.remove_entry(&node.obligation.as_cache_key())
622627
{
623628
self.done_cache.insert(predicate);
624629
} else {
625-
self.done_cache.insert(node.obligation.as_predicate().clone());
630+
self.done_cache.insert(node.obligation.as_cache_key().clone());
626631
}
627632
if do_completed == DoCompleted::Yes {
628633
// Extract the success stories.
@@ -635,7 +640,7 @@ impl<O: ForestObligation> ObligationForest<O> {
635640
// We *intentionally* remove the node from the cache at this point. Otherwise
636641
// tests must come up with a different type on every type error they
637642
// check against.
638-
self.active_cache.remove(node.obligation.as_predicate());
643+
self.active_cache.remove(&node.obligation.as_cache_key());
639644
self.insert_into_error_cache(index);
640645
node_rewrites[index] = orig_nodes_len;
641646
dead_nodes += 1;

src/librustc_data_structures/obligation_forest/tests.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ use std::fmt;
44
use std::marker::PhantomData;
55

66
impl<'a> super::ForestObligation for &'a str {
7-
type Predicate = &'a str;
7+
type CacheKey = &'a str;
88

9-
fn as_predicate(&self) -> &Self::Predicate {
9+
fn as_cache_key(&self) -> Self::CacheKey {
1010
self
1111
}
1212
}

0 commit comments

Comments
 (0)