Skip to content

Commit 4a7c955

Browse files
authoredSep 22, 2024
Merge pull request #573 from dhruvmanila/dhruv/recreate-panic
Fix panic when recreating tracked struct that was deleted in previous revision
2 parents 198c43f + 8094e0c commit 4a7c955

8 files changed

+87
-23
lines changed
 

‎src/active_query.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::{
77
key::{DatabaseKeyIndex, DependencyIndex},
88
tracked_struct::{Disambiguator, KeyStruct},
99
zalsa_local::EMPTY_DEPENDENCIES,
10-
Cycle, Id, Revision,
10+
Cycle, Revision,
1111
};
1212

1313
use super::zalsa_local::{EdgeKind, QueryEdges, QueryOrigin, QueryRevisions};
@@ -49,7 +49,7 @@ pub(crate) struct ActiveQuery {
4949

5050
/// Map from tracked struct keys (which include the hash + disambiguator) to their
5151
/// final id.
52-
pub(crate) tracked_struct_ids: FxHashMap<KeyStruct, Id>,
52+
pub(crate) tracked_struct_ids: FxHashMap<KeyStruct, DatabaseKeyIndex>,
5353

5454
/// Stores the values accumulated to the given ingredient.
5555
/// The type of accumulated value is erased but known to the ingredient.

‎src/function/diff_outputs.rs

+18-10
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,43 @@
1+
use super::{memo::Memo, Configuration, IngredientImpl};
12
use crate::{
23
hash::FxHashSet, key::DependencyIndex, zalsa_local::QueryRevisions, AsDynDatabase as _,
34
DatabaseKeyIndex, Event, EventKind,
45
};
56

6-
use super::{memo::Memo, Configuration, IngredientImpl};
7-
87
impl<C> IngredientImpl<C>
98
where
109
C: Configuration,
1110
{
1211
/// Compute the old and new outputs and invoke the `clear_stale_output` callback
1312
/// for each output that was generated before but is not generated now.
13+
///
14+
/// This function takes a `&mut` reference to `revisions` to remove outputs
15+
/// that no longer exist in this revision from [`QueryRevisions::tracked_struct_ids`].
1416
pub(super) fn diff_outputs(
1517
&self,
1618
db: &C::DbView,
1719
key: DatabaseKeyIndex,
1820
old_memo: &Memo<C::Output<'_>>,
19-
revisions: &QueryRevisions,
21+
revisions: &mut QueryRevisions,
2022
) {
2123
// Iterate over the outputs of the `old_memo` and put them into a hashset
22-
let mut old_outputs = FxHashSet::default();
23-
old_memo.revisions.origin.outputs().for_each(|i| {
24-
old_outputs.insert(i);
25-
});
24+
let mut old_outputs: FxHashSet<_> = old_memo.revisions.origin.outputs().collect();
2625

2726
// Iterate over the outputs of the current query
2827
// and remove elements from `old_outputs` when we find them
2928
for new_output in revisions.origin.outputs() {
30-
if old_outputs.contains(&new_output) {
31-
old_outputs.remove(&new_output);
32-
}
29+
old_outputs.remove(&new_output);
30+
}
31+
32+
if !old_outputs.is_empty() {
33+
// Remove the outputs that are no longer present in the current revision
34+
// to prevent that the next revision is seeded with a id mapping that no longer exists.
35+
revisions.tracked_struct_ids.retain(|_k, value| {
36+
!old_outputs.contains(&DependencyIndex {
37+
ingredient_index: value.ingredient_index,
38+
key_index: Some(value.key_index),
39+
})
40+
});
3341
}
3442

3543
for old_output in old_outputs {

‎src/function/execute.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ where
8080
// old value.
8181
if let Some(old_memo) = &opt_old_memo {
8282
self.backdate_if_appropriate(old_memo, &mut revisions, &value);
83-
self.diff_outputs(db, database_key_index, old_memo, &revisions);
83+
self.diff_outputs(db, database_key_index, old_memo, &mut revisions);
8484
}
8585

8686
tracing::debug!("{database_key_index:?}: read_upgrade: result.revisions = {revisions:#?}");

‎src/function/fetch.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ impl<C> IngredientImpl<C>
66
where
77
C: Configuration,
88
{
9-
pub fn fetch<'db>(&'db self, db: &'db C::DbView, id: Id) -> &C::Output<'db> {
9+
pub fn fetch<'db>(&'db self, db: &'db C::DbView, id: Id) -> &'db C::Output<'db> {
1010
let (zalsa, zalsa_local) = db.zalsas();
1111
zalsa_local.unwind_if_revision_cancelled(db.as_dyn_database());
1212

‎src/function/specify.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ where
7474

7575
if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key) {
7676
self.backdate_if_appropriate(&old_memo, &mut revisions, &value);
77-
self.diff_outputs(db, database_key_index, &old_memo, &revisions);
77+
self.diff_outputs(db, database_key_index, &old_memo, &mut revisions);
7878
}
7979

8080
let memo = Memo {

‎src/tracked_struct.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,9 @@ where
276276
None => {
277277
// This is a new tracked struct, so create an entry in the struct map.
278278
let id = self.allocate(zalsa, zalsa_local, current_revision, &current_deps, fields);
279-
zalsa_local.add_output(self.database_key_index(id).into());
280-
zalsa_local.store_tracked_struct_id(key_struct, id);
279+
let key = self.database_key_index(id);
280+
zalsa_local.add_output(key.into());
281+
zalsa_local.store_tracked_struct_id(key_struct, key);
281282
C::struct_from_id(id)
282283
}
283284
}

‎src/zalsa_local.rs

+26-6
Original file line numberDiff line numberDiff line change
@@ -290,12 +290,15 @@ impl ZalsaLocal {
290290
);
291291
self.with_query_stack(|stack| {
292292
let top_query = stack.last().unwrap();
293-
top_query.tracked_struct_ids.get(key_struct).cloned()
293+
top_query
294+
.tracked_struct_ids
295+
.get(key_struct)
296+
.map(|index| index.key_index())
294297
})
295298
}
296299

297300
#[track_caller]
298-
pub(crate) fn store_tracked_struct_id(&self, key_struct: KeyStruct, id: Id) {
301+
pub(crate) fn store_tracked_struct_id(&self, key_struct: KeyStruct, id: DatabaseKeyIndex) {
299302
debug_assert!(
300303
self.query_in_progress(),
301304
"cannot create a tracked struct disambiguator outside of a tracked function"
@@ -358,9 +361,23 @@ pub(crate) struct QueryRevisions {
358361
pub(crate) origin: QueryOrigin,
359362

360363
/// The ids of tracked structs created by this query.
361-
/// This is used to seed the next round if the query is
362-
/// re-executed.
363-
pub(super) tracked_struct_ids: FxHashMap<KeyStruct, Id>,
364+
///
365+
/// This table plays an important role when queries are
366+
/// re-executed:
367+
/// * A clone of this field is used as the initial set of
368+
/// `TrackedStructId`s for the query on the next execution.
369+
/// * The query will thus re-use the same ids if it creates
370+
/// tracked structs with the same `KeyStruct` as before.
371+
/// It may also create new tracked structs.
372+
/// * One tricky case involves deleted structs. If
373+
/// the old revision created a struct S but the new
374+
/// revision did not, there will still be a map entry
375+
/// for S. This is because queries only ever grow the map
376+
/// and they start with the same entries as from the
377+
/// previous revision. To handle this, `diff_outputs` compares
378+
/// the structs from the old/new revision and retains
379+
/// only entries that appeared in the new revision.
380+
pub(super) tracked_struct_ids: FxHashMap<KeyStruct, DatabaseKeyIndex>,
364381

365382
pub(super) accumulated: AccumulatedMap,
366383
}
@@ -519,7 +536,10 @@ impl ActiveQueryGuard<'_> {
519536
}
520537

521538
/// Initialize the tracked struct ids with the values from the prior execution.
522-
pub(crate) fn seed_tracked_struct_ids(&self, tracked_struct_ids: &FxHashMap<KeyStruct, Id>) {
539+
pub(crate) fn seed_tracked_struct_ids(
540+
&self,
541+
tracked_struct_ids: &FxHashMap<KeyStruct, DatabaseKeyIndex>,
542+
) {
523543
self.local_state.with_query_stack(|stack| {
524544
assert_eq!(stack.len(), self.push_len);
525545
let frame = stack.last_mut().unwrap();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//! Test that re-creating a `tracked` struct after it was deleted in a previous
2+
//! revision doesn't panic.
3+
#![allow(warnings)]
4+
5+
use salsa::Setter;
6+
7+
#[salsa::input]
8+
struct MyInput {
9+
field: u32,
10+
}
11+
12+
#[salsa::tracked]
13+
struct TrackedStruct<'db> {
14+
field: u32,
15+
}
16+
17+
#[salsa::tracked]
18+
fn tracked_fn(db: &dyn salsa::Database, input: MyInput) -> Option<TrackedStruct<'_>> {
19+
if input.field(db) == 1 {
20+
Some(TrackedStruct::new(db, 1))
21+
} else {
22+
None
23+
}
24+
}
25+
26+
#[test]
27+
fn execute() {
28+
let mut db = salsa::DatabaseImpl::new();
29+
let input = MyInput::new(&db, 1);
30+
assert!(tracked_fn(&db, input).is_some());
31+
input.set_field(&mut db).to(0);
32+
assert_eq!(tracked_fn(&db, input), None);
33+
input.set_field(&mut db).to(1);
34+
assert!(tracked_fn(&db, input).is_some());
35+
}

0 commit comments

Comments
 (0)
Please sign in to comment.