fix(graph): merge_context index_ref, check_valid_context, pop_states aliasing#212
Conversation
…aliasing Audit of brainstate.graph found three correctness bugs: - merge_context yielded `dict(index_ref)` — an empty snapshot disconnected from the table that `treefy_merge` populates. Yield the live dict so it is symmetric with `split_context` (whose live `ref_index` is consumed by `graph_to_tree`). - Node.check_valid_context read `self._trace_state`, which graph nodes never carry (only `State` objects do), so it raised `AttributeError` for every graph node (e.g. `nn.Linear`). A node has no trace state of its own; its validity is the conjunction of the trace validity of the `State`s reachable from it — iterate them and raise `TraceContextError` on any invalid one. - pop_states deduplicated a matched `State` by identity and popped only its first reference; later shared/tied aliases were skipped and left dangling on the node (e.g. `baz.weight = bar.weight` then popping ParamState removed `bar.weight` but kept `baz.weight`). Detach every alias of a popped state while still recording it once. Adds regression tests for each. graph (192), transform (1136) and nn (1773) suites pass.
Reviewer's GuideFixes three graph correctness issues: merge_context now yields the live index_ref table instead of a snapshot, Node.check_valid_context validates reachable State objects instead of accessing a nonexistent _trace_state, and _graph_pop/pop_states now fully detaches all aliases of popped states while still deduplicating them by identity; all covered by new regression tests. Sequence diagram for merge_context yielding live index_refsequenceDiagram
participant Caller
participant merge_context
participant GRAPH_CONTEXT
Caller->>merge_context: enter merge_context()
merge_context->>GRAPH_CONTEXT: index_ref_stack.append(unflatten_ctx)
merge_context-->>Caller: yield unflatten_ctx, index_ref
Caller->>index_ref: update index_to_object_mappings
Caller-->>merge_context: exit context
merge_context->>GRAPH_CONTEXT: index_ref_stack.pop()
merge_context->>unflatten_ctx: del index_ref
Sequence diagram for _graph_pop detaching shared State aliasessequenceDiagram
participant Caller
participant _graph_pop
participant impl
participant graph_node
Caller->>_graph_pop: _graph_pop(graph_node, id_to_index,...)
_graph_pop->>graph_node: traverse children
_graph_pop->>graph_node: visit first State_leaf
_graph_pop->>id_to_index: add id(State)
_graph_pop->>impl: pop_key(graph_node, name_first)
_graph_pop->>graph_node: visit second alias of same State
_graph_pop->>id_to_index: check id(State) in id_to_index
alt [id(State) already in id_to_index]
_graph_pop->>impl: pop_key(graph_node, name_alias)
_graph_pop-->>Caller: do not record State again
end
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
There was a problem hiding this comment.
Hey - I've found 1 issue
Prompt for AI Agents
Please address the comments from this code review:
## Individual Comments
### Comment 1
<location path="brainstate/graph/_context_test.py" line_range="105-114" />
<code_context>
self.assertIsNotNone(mctx)
self.assertFalse(hasattr(mctx, 'index_ref'))
+ def test_merge_context_exposes_live_index_ref(self):
+ """``merge_context`` must yield the *live* index_ref table.
+
+ Regression: it previously yielded ``dict(unflatten_ctx.index_ref)`` — a
+ snapshot taken at yield time (empty) and disconnected from the table that
+ ``treefy_merge`` actually populates. This is now symmetric with
+ ``split_context``, which yields its live ``RefMap``.
+ """
+ m = brainstate.nn.Linear(2, 3)
+ graphdef, state = brainstate.graph.treefy_split(m)
+ with brainstate.graph.merge_context() as (ctx, index_ref):
+ rebuilt = ctx.treefy_merge(graphdef, state)
+ # The yielded table is populated and contains the rebuilt root object.
+ self.assertGreater(len(index_ref), 0)
+ self.assertIn(id(rebuilt), {id(v) for v in index_ref.values()})
+
</code_context>
<issue_to_address>
**suggestion (testing):** Strengthen `merge_context` regression test by asserting that the yielded `index_ref` is the live table, not just populated.
The current assertions only show that `index_ref` is populated with the rebuilt root, but not that it’s the same live mapping held by `MergeContext` rather than a filled copy. To better protect against regressions to the old snapshot behavior, add an identity-style assertion (e.g. `self.assertIs(ctx.index_ref, index_ref)` if applicable) or another invariant that specifically distinguishes a live reference from a snapshot.
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
| def test_merge_context_exposes_live_index_ref(self): | ||
| """``merge_context`` must yield the *live* index_ref table. | ||
|
|
||
| Regression: it previously yielded ``dict(unflatten_ctx.index_ref)`` — a | ||
| snapshot taken at yield time (empty) and disconnected from the table that | ||
| ``treefy_merge`` actually populates. This is now symmetric with | ||
| ``split_context``, which yields its live ``RefMap``. | ||
| """ | ||
| m = brainstate.nn.Linear(2, 3) | ||
| graphdef, state = brainstate.graph.treefy_split(m) |
There was a problem hiding this comment.
suggestion (testing): Strengthen merge_context regression test by asserting that the yielded index_ref is the live table, not just populated.
The current assertions only show that index_ref is populated with the rebuilt root, but not that it’s the same live mapping held by MergeContext rather than a filled copy. To better protect against regressions to the old snapshot behavior, add an identity-style assertion (e.g. self.assertIs(ctx.index_ref, index_ref) if applicable) or another invariant that specifically distinguishes a live reference from a snapshot.
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
Summary
Audit of
brainstate.graphsurfaced three correctness bugs, each reproduced with a regression test before fixing.BUG 1 —
merge_contextyields a disconnected, emptyindex_ref_context.py::merge_contextyieldeddict(unflatten_ctx.index_ref)— a snapshot taken at yield time (empty) and disconnected from the tabletreefy_mergeactually populates. Asymmetric withsplit_context, which yields its liveRefMap(consumed bygraph_to_treeafter the block). Fix: yield the live dict.BUG 2 —
Node.check_valid_contextraisesAttributeError_node.py::Node.check_valid_contextreadself._trace_state, but graph nodes (incl.nn.Linear,nn.Module) carry no_trace_state— onlyStateobjects do — so it raisedAttributeErrorfor every node. (Currently only reached as dead code viacheck_consistent_aliasing, but it is broken public API.) Fix: a node's validity is the conjunction of the trace validity of theStates reachable from it; iterate them and raiseTraceContextErroron any invalid one.BUG 3 —
pop_statesleaves dangling aliases for shared/tied states_operations.py::_graph_popdeduplicated a matchedStateby identity and popped only its first reference; later aliases were skipped and survived. After popping a tied weight (baz.weight = bar.weight),bar.weightwas removed butbaz.weightremained, leaving the node half-mutated. Fix: detach every alias of a popped state, while still recording it once.Withdrawn after investigation
A suspected
KeyError-vs-ValueErrorcontract gap on missingStateLeafEdgepaths turned out to be unreachable:TreefyStateis itself a JAX pytree, soStateLeafEdgeis never produced byflatten(dead code). No change.Testing
brainstate/graph/— 192 passedbrainstate/transform/— 1136 passedbrainstate/nn/— 1773 passed (14 pre-existing unrelated skips)🤖 Generated with Claude Code
Summary by Sourcery
Fix multiple graph correctness issues and add regression coverage for context merging, node trace validation, and popping shared states.
Bug Fixes:
Tests: