@@ -234,6 +234,59 @@ def value(self):
234234
235235 Variable = NnxVariable
236236
237+ def _flatten_nnx_variable (variable ):
238+ children = (variable .raw_value ,)
239+ # We copy __dict__ to avoid side effects
240+ keras_state = variable .__dict__ .copy ()
241+ # Remove elements that might be problematic or redundant if
242+ # nnx.Variable's __getstate__
243+ keras_state .pop ("raw_value" , None )
244+ aux_data = (
245+ variable ._var_metadata ,
246+ getattr (variable , "_trace_state" , None ),
247+ keras_state ,
248+ )
249+ return children , aux_data
250+
251+ def _unflatten_nnx_variable (aux_data , children ):
252+ var_metadata , trace_state , keras_state = aux_data
253+ raw_value = children [0 ]
254+
255+ # Create uninitialized instance
256+ variable = NnxVariable .__new__ (NnxVariable )
257+
258+ # Restore state
259+ variable ._var_metadata = var_metadata
260+ if trace_state is not None :
261+ variable ._trace_state = trace_state
262+ variable .__dict__ .update (keras_state )
263+ variable .raw_value = raw_value
264+
265+ return variable
266+
267+ try :
268+ jax .tree_util .register_pytree_node (
269+ NnxVariable ,
270+ _flatten_nnx_variable ,
271+ _unflatten_nnx_variable ,
272+ )
273+ except ValueError :
274+ pass
275+
276+ def __setattr__ (self , name , value ):
277+ # Mirror Keras attributes to _var_metadata to ensure persistence
278+ # if the Pytree registration is not respected by NNX.
279+ if (
280+ name != "_var_metadata"
281+ and name not in ("_raw_value" , "_trace_state" )
282+ and hasattr (self , "_var_metadata" )
283+ ):
284+ self ._var_metadata [name ] = value
285+
286+ object .__setattr__ (self , name , value )
287+
288+ NnxVariable .__setattr__ = __setattr__
289+
237290
238291def convert_to_tensor (x , dtype = None , sparse = None , ragged = None ):
239292 if ragged :
0 commit comments