Description
I've encountered some unexpected behavior with nnx.scan when using nnx.Module.sow.
When nnx.scan is called on an nnx.Module that operates on a standard JAX pytree (like a flax.struct.dataclass) as its carry, the sow functionality works as expected. The nnx.StateAxes({nnx.Intermediate: 0, ...}) config is "smartly" interpreted as an output-only instruction to stack the sowed variables.
However, if that carry object is changed to be an nnx.Pytree, this "smart" logic breaks. nnx.scan seems to fall back to a "strict" interpretation of in_axes, and now requires a matching nnx.Intermediate variable on input, which causes a TypeError (length/shape mismatch).
This seems counter-intuitive: the carry object is explicitly passed with nnx.Carry, so it should (in theory) be treated as a standard JAX carry and should not interfere with the state management of the nnx.Module being scanned.
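For reference, the semantics I would expect from the sowed values mirror plain jax.lax.scan: the carry is threaded through the loop unchanged in structure, and per-step outputs are stacked along axis 0. The snippet below is only an illustrative sketch in plain JAX, not a description of how nnx.scan is implemented; the function name step is a placeholder chosen for this example.

```python
import jax
import jax.numpy as jnp


def step(carry, _):
    # Advance the carry by one, as the repro's _step does with state.data.
    new_carry = carry + 1
    # The second return value plays the role of the sowed intermediate:
    # jax.lax.scan stacks it along a new leading axis.
    return new_carry, new_carry


# xs=None with an explicit length runs the step function 5 times.
final, stacked = jax.lax.scan(step, jnp.array(0.0), None, length=5)

print(final)          # scalar carry after 5 steps
print(stacked.shape)  # (5,)
```

Here the carry type plays no part in how the per-step outputs are stacked, which is the behavior I expected to see regardless of whether the carry is a flax.struct.dataclass or an nnx.Pytree.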
Minimal Reproduction Code
The following code demonstrates the issue. The only difference between the two tests is the class definition of the state object (StateAsDataclass vs. StateAsPytree).
from functools import partial

import jax
import jax.numpy as jnp
from flax import nnx, struct


# Case 1: Standard JAX Pytree (flax.struct.dataclass)
@struct.dataclass
class StateAsDataclass:
    data: jax.Array


# Case 2: NNX Pytree
class StateAsPytree(nnx.Pytree):
    def __init__(self, data: jax.Array):
        self.data = data


class Model(nnx.Module):
    def _step(self, state, sow: bool):
        new_data = state.data + 1
        if sow:
            self.sow(nnx.Intermediate, "data", new_data)
        # Return the next carry, which must have the same type
        if isinstance(state, StateAsDataclass):
            return state.replace(data=new_data)
        else:
            state.data = new_data
            return state

    @partial(nnx.jit, static_argnames=("num_steps", "sow"))
    def __call__(self, state, *, num_steps: int, sow: bool = False):
        state_axes = nnx.StateAxes({nnx.Intermediate: 0, ...: nnx.Carry})
        state_final = nnx.scan(
            partial(Model._step, sow=sow),
            in_axes=(state_axes, nnx.Carry),
            out_axes=nnx.Carry,
            length=num_steps,
        )(self, state)
        return state_final

NUM_STEPS = 5
# --- Test Case 1: StateAsDataclass (This works) ---
print("--- Testing with StateAsDataclass ---")
model = Model()
state = StateAsDataclass(data=jnp.array(0.0))
state_final = model(state, num_steps=NUM_STEPS, sow=True)
intermediates = nnx.pop(model, nnx.Intermediate)
print(f"Final data: {state_final.data}")
print(f"Sowed data shape: {intermediates['data'].value[0].shape}")
assert intermediates["data"].value[0].shape == (NUM_STEPS,)
print("SUCCESS: `sow` worked as expected.\n")
# --- Test Case 2: StateAsPytree (This fails) ---
print("--- Testing with StateAsPytree ---")
model = Model()
state = StateAsPytree(data=jnp.array(0.0))
state_final = model(state, num_steps=NUM_STEPS, sow=True)
intermediates = nnx.pop(model, nnx.Intermediate)
print(f"Final data: {state_final.data}")
print(f"Sowed data shape: {intermediates['data'].value[0].shape}")
assert intermediates["data"].value[0].shape == (NUM_STEPS,)
print("SUCCESS\n")

Is this the intended behavior? It seems confusing that simply changing the type of an explicit, non-managed carry object would break the sow logic for the managed module.
System information
- Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib):
Name: flax
Version: 0.12.0
Location: /home/projects/cax/.venv/lib/python3.13/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, pyyaml, rich, tensorstore, treescope, typing-extensions
Required-by: cax, evosax
---
Name: jax
Version: 0.8.0
Location: /home/projects/cax/.venv/lib/python3.13/site-packages
Requires: jaxlib, ml-dtypes, numpy, opt-einsum, scipy
Required-by: cax, chex, evosax, flax, optax, orbax-checkpoint
---
Name: jaxlib
Version: 0.8.0
Location: /home/projects/cax/.venv/lib/python3.13/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, jax, optax
- Python version: 3.13.3
- GPU/TPU model and memory: NVIDIA H100 80GB HBM3