Skip to content

Commit ff48e6f

Browse files
fix NNX tests
1 parent f2c00fe commit ff48e6f

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed

keras/src/backend/jax/core.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,59 @@ def value(self):
234234

235235
Variable = NnxVariable
236236

237+
def _flatten_nnx_variable(variable):
238+
children = (variable.raw_value,)
239+
# We copy __dict__ to avoid side effects
240+
keras_state = variable.__dict__.copy()
241+
# Remove elements that might be problematic or redundant if
242+
# nnx.Variable's __getstate__
243+
keras_state.pop("raw_value", None)
244+
aux_data = (
245+
variable._var_metadata,
246+
getattr(variable, "_trace_state", None),
247+
keras_state,
248+
)
249+
return children, aux_data
250+
251+
def _unflatten_nnx_variable(aux_data, children):
252+
var_metadata, trace_state, keras_state = aux_data
253+
raw_value = children[0]
254+
255+
# Create uninitialized instance
256+
variable = NnxVariable.__new__(NnxVariable)
257+
258+
# Restore state
259+
variable._var_metadata = var_metadata
260+
if trace_state is not None:
261+
variable._trace_state = trace_state
262+
variable.__dict__.update(keras_state)
263+
variable.raw_value = raw_value
264+
265+
return variable
266+
267+
try:
268+
jax.tree_util.register_pytree_node(
269+
NnxVariable,
270+
_flatten_nnx_variable,
271+
_unflatten_nnx_variable,
272+
)
273+
except ValueError:
274+
pass
275+
276+
def __setattr__(self, name, value):
277+
# Mirror Keras attributes to _var_metadata to ensure persistence
278+
# if the Pytree registration is not respected by NNX.
279+
if (
280+
name != "_var_metadata"
281+
and name not in ("_raw_value", "_trace_state")
282+
and hasattr(self, "_var_metadata")
283+
):
284+
self._var_metadata[name] = value
285+
286+
object.__setattr__(self, name, value)
287+
288+
NnxVariable.__setattr__ = __setattr__
289+
237290

238291
def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
239292
if ragged:

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ torch-xla==2.6.0;sys_platform != 'darwin'
1313
# Jax.
1414
# Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test.
1515
# Note that we test against the latest JAX on GPU.
16-
jax[cpu]==0.5.0
17-
flax
16+
jax[cpu]==0.8.1
17+
flax==0.12.1
1818

1919
# Common deps.
2020
-r requirements-common.txt

0 commit comments

Comments
 (0)