nnx.scan sow behavior changes when an nnx.Pytree is used as an explicit nnx.Carry #5051

@maxencefaldor

Description

I've encountered some unexpected behavior with nnx.scan when using nnx.Module.sow.

When nnx.scan is called on an nnx.Module that operates on a standard JAX pytree (like a flax.struct.dataclass) as its carry, the sow functionality works as expected. The nnx.StateAxes({nnx.Intermediate: 0, ...}) config is "smartly" interpreted as an output-only instruction to stack the sowed variables.
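For reference, this output-only stacking mirrors how plain jax.lax.scan stacks per-step outputs along axis 0 without requiring any corresponding per-step input. A minimal sketch (plain JAX, not nnx; illustrative of the expected stacking semantics only):

```python
import jax
import jax.numpy as jnp


def step(carry, _):
	# Output-only stacking: each per-step `new` is stacked along axis 0,
	# with no matching per-step input required (xs=None).
	new = carry + 1.0
	return new, new  # (next carry, per-step output to stack)


final, ys = jax.lax.scan(step, jnp.array(0.0), xs=None, length=5)
# final == 5.0, ys.shape == (5,)
```

This is the behavior the nnx.StateAxes({nnx.Intermediate: 0, ...}) config appears to reproduce in Case 1: the sowed Intermediate is stacked on output only.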

However, if that carry object is changed to be an nnx.Pytree, this "smart" logic breaks. nnx.scan seems to fall back to a "strict" interpretation of in_axes, and now requires a matching nnx.Intermediate variable on input, which causes a TypeError (length/shape mismatch).
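By contrast, the "strict" interpretation matches jax.lax.scan's handling of mapped inputs: anything declared as mapped over axis 0 must actually be supplied, with a leading dimension that defines (or matches) the scan length. A hedged plain-JAX sketch of that requirement (illustrative only, not the nnx code path):

```python
import jax
import jax.numpy as jnp


def step(carry, x):
	# This step consumes a per-step input `x`; scan therefore requires
	# an `xs` array whose leading dimension is the scan length.
	return carry + x, carry


final, ys = jax.lax.scan(step, jnp.array(0.0), xs=jnp.ones(5))
# final == 5.0; omitting xs (or giving it the wrong length) is an error.
```

In the failing case below, nnx.scan seems to apply this strict reading to the nnx.Intermediate axis, demanding a matching input that does not exist yet.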

This seems counter-intuitive: the carry object is explicitly passed with nnx.Carry and should, in theory, be treated as a standard JAX carry rather than interfering with the state management of the nnx.Module being scanned.

Minimal Reproduction Code

The following code demonstrates the issue. The only difference between the two tests is the class definition of the state object (StateAsDataclass vs. StateAsPytree).

from functools import partial

import jax
import jax.numpy as jnp
from flax import nnx, struct


# Case 1: Standard JAX Pytree (flax.struct.dataclass)
@struct.dataclass
class StateAsDataclass:
	data: jax.Array


# Case 2: NNX Pytree
class StateAsPytree(nnx.Pytree):
	def __init__(self, data: jax.Array):
		self.data = data


class Model(nnx.Module):
	def _step(self, state, sow: bool):
		new_data = state.data + 1

		if sow:
			self.sow(nnx.Intermediate, "data", new_data)

		# Return the next carry, which must have the same type
		if isinstance(state, StateAsDataclass):
			return state.replace(data=new_data)
		else:
			state.data = new_data
			return state

	@partial(nnx.jit, static_argnames=("num_steps", "sow"))
	def __call__(self, state, *, num_steps: int, sow: bool = False):
		state_axes = nnx.StateAxes({nnx.Intermediate: 0, ...: nnx.Carry})
		state_final = nnx.scan(
			partial(Model._step, sow=sow),
			in_axes=(state_axes, nnx.Carry), 
			out_axes=nnx.Carry,
			length=num_steps,
		)(self, state)

		return state_final


NUM_STEPS = 5

# --- Test Case 1: StateAsDataclass (This works) ---
print("--- Testing with StateAsDataclass ---")
model = Model()
state = StateAsDataclass(data=jnp.array(0.0))

state_final = model(state, num_steps=NUM_STEPS, sow=True)
intermediates = nnx.pop(model, nnx.Intermediate)

print(f"Final data: {state_final.data}")
print(f"Sowed data shape: {intermediates['data'].value[0].shape}")
assert intermediates["data"].value[0].shape == (NUM_STEPS,)
print("SUCCESS: `sow` worked as expected.\n")


# --- Test Case 2: StateAsPytree (This fails) ---
print("--- Testing with StateAsPytree ---")
model = Model()
state = StateAsPytree(data=jnp.array(0.0))

state_final = model(state, num_steps=NUM_STEPS, sow=True)
intermediates = nnx.pop(model, nnx.Intermediate)

print(f"Final data: {state_final.data}")
print(f"Sowed data shape: {intermediates['data'].value[0].shape}")
assert intermediates["data"].value[0].shape == (NUM_STEPS,)
print("SUCCESS\n")

Is this the intended behavior? It seems confusing that simply changing the type of an explicit, non-managed carry object would break the sow logic for the managed module.

System information

  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib):
Name: flax
Version: 0.12.0
Location: /home/projects/cax/.venv/lib/python3.13/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, pyyaml, rich, tensorstore, treescope, typing-extensions
Required-by: cax, evosax
---
Name: jax
Version: 0.8.0
Location: /home/projects/cax/.venv/lib/python3.13/site-packages
Requires: jaxlib, ml-dtypes, numpy, opt-einsum, scipy
Required-by: cax, chex, evosax, flax, optax, orbax-checkpoint
---
Name: jaxlib
Version: 0.8.0
Location: /home/projects/cax/.venv/lib/python3.13/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, jax, optax
  • Python version: 3.13.3
  • GPU/TPU model and memory: NVIDIA H100 80GB HBM3
