Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/active_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ impl ActiveQuery {
.mark_all_active(active_tracked_ids.iter().copied());
}

pub(super) fn take_cycle_heads(&mut self) -> CycleHeads {
std::mem::take(&mut self.cycle_heads)
}

pub(super) fn add_read(
&mut self,
input: DatabaseKeyIndex,
Expand Down
4 changes: 4 additions & 0 deletions src/cycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -490,4 +490,8 @@ impl<'db> ProvisionalStatus<'db> {
_ => empty_cycle_heads(),
}
}

pub(crate) const fn is_provisional(&self) -> bool {
matches!(self, ProvisionalStatus::Provisional { .. })
}
}
61 changes: 43 additions & 18 deletions src/function/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,25 @@ where
});

let (new_value, mut completed_query) = match C::CYCLE_STRATEGY {
CycleRecoveryStrategy::Panic => Self::execute_query(
db,
zalsa,
zalsa_local.push_query(database_key_index, IterationCount::initial()),
opt_old_memo,
),
CycleRecoveryStrategy::Panic => {
let (new_value, active_query) = Self::execute_query(
db,
zalsa,
zalsa_local.push_query(database_key_index, IterationCount::initial()),
opt_old_memo,
);
(new_value, active_query.pop())
}
CycleRecoveryStrategy::FallbackImmediate => {
let (mut new_value, mut completed_query) = Self::execute_query(
let (mut new_value, active_query) = Self::execute_query(
db,
zalsa,
zalsa_local.push_query(database_key_index, IterationCount::initial()),
opt_old_memo,
);

let mut completed_query = active_query.pop();

if let Some(cycle_heads) = completed_query.revisions.cycle_heads_mut() {
// Did the new result we got depend on our own provisional value, in a cycle?
if cycle_heads.contains(&database_key_index) {
Expand Down Expand Up @@ -198,9 +203,10 @@ where

let _poison_guard =
PoisonProvisionalIfPanicking::new(self, zalsa, id, memo_ingredient_index);
let mut active_query = zalsa_local.push_query(database_key_index, iteration_count);

let (new_value, completed_query) = loop {
let active_query = zalsa_local.push_query(database_key_index, iteration_count);

// Tracked struct ids that existed in the previous revision
// but weren't recreated in the last iteration. It's important that we seed the next
// query with these ids because the query might re-create them as part of the next iteration.
Expand All @@ -209,29 +215,32 @@ where
// if they aren't recreated when reaching the final iteration.
active_query.seed_tracked_struct_ids(&last_stale_tracked_ids);

let (mut new_value, mut completed_query) = Self::execute_query(
let (mut new_value, mut active_query) = Self::execute_query(
db,
zalsa,
active_query,
last_provisional_memo.or(opt_old_memo),
);

// If there are no cycle heads, break out of the loop (`cycle_heads_mut` returns `None` if the cycle head list is empty)
let Some(cycle_heads) = completed_query.revisions.cycle_heads_mut() else {
// Take the cycle heads to not-fight-rust's-borrow-checker.
let mut cycle_heads = active_query.take_cycle_heads();

// If there are no cycle heads, break out of the loop.
if cycle_heads.is_empty() {
iteration_count = iteration_count.increment().unwrap_or_else(|| {
tracing::warn!("{database_key_index:?}: execute: too many cycle iterations");
panic!("{database_key_index:?}: execute: too many cycle iterations")
});

let mut completed_query = active_query.pop();
completed_query
.revisions
.update_iteration_count_mut(database_key_index, iteration_count);

claim_guard.set_release_mode(ReleaseMode::SelfOnly);
break (new_value, completed_query);
};
}

// Take the cycle heads to not-fight-rust's-borrow-checker.
let mut cycle_heads = std::mem::take(cycle_heads);
let mut missing_heads: SmallVec<[(DatabaseKeyIndex, IterationCount); 1]> =
SmallVec::new_const();
let mut max_iteration_count = iteration_count;
Expand Down Expand Up @@ -262,6 +271,11 @@ where
.provisional_status(zalsa, head.database_key_index.key_index())
.expect("cycle head memo must have been created during the execution");

// A cycle head isn't allowed to be final because that would mean that this query
// dependents on the initial value of B, but B finalized without
// completing this query that also participates in the cycle.
assert!(provisional_status.is_provisional());

for nested_head in provisional_status.cycle_heads() {
let nested_as_tuple = (
nested_head.database_key_index,
Expand Down Expand Up @@ -298,6 +312,8 @@ where
claim_guard.set_release_mode(ReleaseMode::SelfOnly);
}

let mut completed_query = active_query.pop();
*completed_query.revisions.verified_final.get_mut() = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understand in what case this would be true and we need to reset it to false.

Copy link
Contributor Author

@MichaReiser MichaReiser Nov 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QueryRevisions::new defaults to true if cycle heads are empty. The cycle heads are always empty when calling pop because of the preceding query_guard.take_cycle_heads

completed_query.revisions.set_cycle_heads(cycle_heads);

iteration_count = iteration_count.increment().unwrap_or_else(|| {
Expand Down Expand Up @@ -378,8 +394,17 @@ where
this_converged = C::values_equal(&new_value, last_provisional_value);
}
}

let new_cycle_heads = active_query.take_cycle_heads();
for head in new_cycle_heads {
if !cycle_heads.contains(&head.database_key_index) {
panic!("Cycle recovery function for {database_key_index:?} introduced a cycle, depending on {:?}. This is not allowed.", head.database_key_index);
}
}
}

let mut completed_query = active_query.pop();

if let Some(outer_cycle) = outer_cycle {
tracing::info!(
"Detected nested cycle {database_key_index:?}, iterate it as part of the outer cycle {outer_cycle:?}"
Expand All @@ -390,6 +415,7 @@ where
completed_query
.revisions
.set_cycle_converged(this_converged);
*completed_query.revisions.verified_final.get_mut() = false;

// Transfer ownership of this query to the outer cycle, so that it can claim it
// and other threads don't compete for the same lock.
Expand Down Expand Up @@ -428,9 +454,9 @@ where
}

*completed_query.revisions.verified_final.get_mut() = true;

break (new_value, completed_query);
}
*completed_query.revisions.verified_final.get_mut() = false;

// The fixpoint iteration hasn't converged. Iterate again...
iteration_count = iteration_count.increment().unwrap_or_else(|| {
Expand Down Expand Up @@ -484,7 +510,6 @@ where
last_provisional_memo = Some(new_memo);

last_stale_tracked_ids = completed_query.stale_tracked_structs;
active_query = zalsa_local.push_query(database_key_index, iteration_count);

continue;
};
Expand All @@ -503,7 +528,7 @@ where
zalsa: &'db Zalsa,
active_query: ActiveQueryGuard<'db>,
opt_old_memo: Option<&Memo<'db, C>>,
) -> (C::Output<'db>, CompletedQuery) {
) -> (C::Output<'db>, ActiveQueryGuard<'db>) {
if let Some(old_memo) = opt_old_memo {
// If we already executed this query once, then use the tracked-struct ids from the
// previous execution as the starting point for the new one.
Expand All @@ -528,7 +553,7 @@ where
C::id_to_input(zalsa, active_query.database_key_index.key_index()),
);

(new_value, active_query.pop())
(new_value, active_query)
}
}

Expand Down
5 changes: 4 additions & 1 deletion src/function/maybe_changed_after.rs
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,10 @@ where
cycle_heads.append_heads(&mut child_cycle_heads);

match input_result {
VerifyResult::Changed => return VerifyResult::changed(),
VerifyResult::Changed => {
cycle_heads.remove_head(database_key_index);
return VerifyResult::changed();
}
#[cfg(feature = "accumulator")]
VerifyResult::Unchanged { accumulated } => {
inputs |= accumulated;
Expand Down
12 changes: 12 additions & 0 deletions src/zalsa_local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,18 @@ impl ActiveQueryGuard<'_> {
}
}

pub(crate) fn take_cycle_heads(&mut self) -> CycleHeads {
// SAFETY: We do not access the query stack reentrantly.
unsafe {
self.local_state.with_query_stack_unchecked_mut(|stack| {
#[cfg(debug_assertions)]
assert_eq!(stack.len(), self.push_len);
let frame = stack.last_mut().unwrap();
frame.take_cycle_heads()
})
}
}

/// Invoked when the query has successfully completed execution.
fn complete(self) -> CompletedQuery {
// SAFETY: We do not access the query stack reentrantly.
Expand Down
82 changes: 82 additions & 0 deletions tests/cycle_recovery_dependencies.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#![cfg(feature = "inventory")]

//! Queries or inputs read within the cycle recovery function
//! are tracked on the cycle function and don't "leak" into the
//! function calling the query with cycle handling.
use expect_test::expect;
use salsa::Setter as _;

use crate::common::LogDatabase;

mod common;

#[salsa::input]
struct Input {
value: u32,
}

#[salsa::tracked]
fn entry(db: &dyn salsa::Database, input: Input) -> u32 {
query(db, input)
}

#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)]
fn query(db: &dyn salsa::Database, input: Input) -> u32 {
let val = query(db, input);
if val < 5 {
val + 1
} else {
val
}
}

fn cycle_initial(_db: &dyn salsa::Database, _id: salsa::Id, _input: Input) -> u32 {
0
}

fn cycle_fn(
db: &dyn salsa::Database,
_id: salsa::Id,
_last_provisional_value: &u32,
_value: &u32,
_count: u32,
input: Input,
) -> salsa::CycleRecoveryAction<u32> {
let _input = input.value(db);
salsa::CycleRecoveryAction::Iterate
}

#[test_log::test]
fn the_test() {
let mut db = common::EventLoggerDatabase::default();

let input = Input::new(&db, 1);
assert_eq!(entry(&db, input), 5);

db.assert_logs_len(15);

input.set_value(&mut db).to(2);

assert_eq!(entry(&db, input), 5);
db.assert_logs(expect![[r#"
[
"DidSetCancellationFlag",
"WillCheckCancellation",
"WillCheckCancellation",
"WillCheckCancellation",
"WillExecute { database_key: query(Id(0)) }",
"WillCheckCancellation",
"WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(1) }",
"WillCheckCancellation",
"WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(2) }",
"WillCheckCancellation",
"WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(3) }",
"WillCheckCancellation",
"WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(4) }",
"WillCheckCancellation",
"WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(5) }",
"WillCheckCancellation",
"DidValidateMemoizedValue { database_key: entry(Id(0)) }",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before the fix, this incorrectly re-executed entry but not the cycle (which is what depends on the value).

With the fix in execute_maybe_iterate, it did re-execute the cycle, but it also re-executed entry, which is incorrect because query returns the same value (it can be backdated). Now, the behavior is what we want. entry does not get re-executed but the cycle is

]"#]]);
}