Skip to content

Commit a885bb4

Browse files
authored
Fix cycle head durability (#1024)
1 parent 05a9af7 commit a885bb4

File tree

3 files changed

+101
-11
lines changed

3 files changed

+101
-11
lines changed

src/function/execute.rs

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ where
170170

171171
// Our provisional value from the previous iteration, when doing fixpoint iteration.
172172
// This is different from `opt_old_memo` which might be from a different revision.
173-
let mut last_provisional_memo: Option<&Memo<'db, C>> = None;
173+
let mut last_provisional_memo_opt: Option<&Memo<'db, C>> = None;
174174

175175
// TODO: Can we seed those somehow?
176176
let mut last_stale_tracked_ids: Vec<(Identity, Id)> = Vec::new();
@@ -194,7 +194,7 @@ where
194194
// Only use the last provisional memo if it was a cycle head in the last iteration. This is to
195195
// force at least two executions.
196196
if old_memo.cycle_heads().contains(&database_key_index) {
197-
last_provisional_memo = Some(old_memo);
197+
last_provisional_memo_opt = Some(old_memo);
198198
}
199199

200200
iteration_count = old_memo.revisions.iteration();
@@ -219,7 +219,7 @@ where
219219
db,
220220
zalsa,
221221
active_query,
222-
last_provisional_memo.or(opt_old_memo),
222+
last_provisional_memo_opt.or(opt_old_memo),
223223
);
224224

225225
// Take the cycle heads to not-fight-rust's-borrow-checker.
@@ -329,10 +329,7 @@ where
329329

330330
// Get the last provisional value for this query so that we can compare it with the new value
331331
// to test if the cycle converged.
332-
let last_provisional_value = if let Some(last_provisional) = last_provisional_memo {
333-
// We have a last provisional value from our previous time around the loop.
334-
last_provisional.value.as_ref()
335-
} else {
332+
let last_provisional_memo = last_provisional_memo_opt.unwrap_or_else(|| {
336333
// This is our first time around the loop; a provisional value must have been
337334
// inserted into the memo table when the cycle was hit, so let's pull our
338335
// initial provisional value from there.
@@ -346,8 +343,10 @@ where
346343
});
347344

348345
debug_assert!(memo.may_be_provisional());
349-
memo.value.as_ref()
350-
};
346+
memo
347+
});
348+
349+
let last_provisional_value = last_provisional_memo.value.as_ref();
351350

352351
let last_provisional_value = last_provisional_value.expect(
353352
"`fetch_cold_cycle` should have inserted a provisional memo with Cycle::initial",
@@ -389,9 +388,24 @@ where
389388
}
390389
}
391390

392-
let this_converged = C::values_equal(&new_value, last_provisional_value);
393391
let mut completed_query = active_query.pop();
394392

393+
let value_converged = C::values_equal(&new_value, last_provisional_value);
394+
395+
// It's important to force a re-execution of the cycle if `changed_at` or `durability` has changed
396+
// to ensure the reduced durability and changed propagates to all queries depending on this head.
397+
let metadata_converged = last_provisional_memo.revisions.durability
398+
== completed_query.revisions.durability
399+
&& last_provisional_memo.revisions.changed_at
400+
== completed_query.revisions.changed_at
401+
&& last_provisional_memo
402+
.revisions
403+
.origin
404+
.is_derived_untracked()
405+
== completed_query.revisions.origin.is_derived_untracked();
406+
407+
let this_converged = value_converged && metadata_converged;
408+
395409
if let Some(outer_cycle) = outer_cycle {
396410
tracing::info!(
397411
"Detected nested cycle {database_key_index:?}, iterate it as part of the outer cycle {outer_cycle:?}"
@@ -494,7 +508,7 @@ where
494508
memo_ingredient_index,
495509
);
496510

497-
last_provisional_memo = Some(new_memo);
511+
last_provisional_memo_opt = Some(new_memo);
498512

499513
last_stale_tracked_ids = completed_query.stale_tracked_structs;
500514

src/zalsa_local.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,10 @@ impl QueryOrigin {
934934
}
935935
}
936936

937+
pub fn is_derived_untracked(&self) -> bool {
938+
matches!(self.kind, QueryOriginKind::DerivedUntracked)
939+
}
940+
937941
/// Create a query origin of type `QueryOriginKind::Derived`, with the given edges.
938942
pub fn derived(input_outputs: Box<[QueryEdge]>) -> QueryOrigin {
939943
// Exceeding `u32::MAX` query edges should never happen in real-world usage.
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#![cfg(feature = "inventory")]
2+
3+
//! Tests that the durability correctly propagates
4+
//! to all cycle heads.
5+
6+
use salsa::Setter as _;
7+
8+
#[test_log::test]
9+
fn low_durability_cycle_enter_from_different_head() {
10+
let mut db = MyDbImpl::default();
11+
// Start with 0, the same as returned by cycle initial
12+
let input = Input::builder(0).new(&db);
13+
db.input = Some(input);
14+
15+
assert_eq!(query_a(&db), 0); // Prime the Db
16+
17+
input.set_value(&mut db).to(10);
18+
19+
assert_eq!(query_b(&db), 10);
20+
}
21+
22+
#[salsa::input]
23+
struct Input {
24+
value: u32,
25+
}
26+
27+
#[salsa::db]
28+
trait MyDb: salsa::Database {
29+
fn input(&self) -> Input;
30+
}
31+
32+
#[salsa::db]
33+
#[derive(Clone, Default)]
34+
struct MyDbImpl {
35+
storage: salsa::Storage<Self>,
36+
input: Option<Input>,
37+
}
38+
39+
#[salsa::db]
40+
impl salsa::Database for MyDbImpl {}
41+
42+
#[salsa::db]
43+
impl MyDb for MyDbImpl {
44+
fn input(&self) -> Input {
45+
self.input.unwrap()
46+
}
47+
}
48+
49+
#[salsa::tracked(cycle_initial=cycle_initial)]
50+
fn query_a(db: &dyn MyDb) -> u32 {
51+
query_b(db);
52+
db.input().value(db)
53+
}
54+
55+
fn cycle_initial(_db: &dyn MyDb, _id: salsa::Id) -> u32 {
56+
0
57+
}
58+
59+
#[salsa::interned]
60+
struct Interned {
61+
value: u32,
62+
}
63+
64+
#[salsa::tracked(cycle_initial=cycle_initial)]
65+
fn query_b<'db>(db: &'db dyn MyDb) -> u32 {
66+
query_c(db)
67+
}
68+
69+
#[salsa::tracked]
70+
fn query_c(db: &dyn MyDb) -> u32 {
71+
query_a(db)
72+
}

0 commit comments

Comments
 (0)