Skip to content

Commit f2c00fe

Browse files
authored
Fix test failures when nnx is enabled (#21875)
* Fix test failures when nnx is enabled * refactoring by gemini review
1 parent cc90ffd commit f2c00fe

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

keras/src/backend/jax/trainer.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,14 @@ def make_train_function(self, force=False):
266266
if distribution_lib.distribution() is not None:
267267
state_shardings = self._get_state_sharding_spec()
268268
out_shardings = (None, state_shardings)
269+
if is_nnx_enabled():
270+
step_fn = lambda state, data: type(self).train_step(
271+
self, state, data
272+
)
273+
else:
274+
step_fn = self.train_step
269275
train_step = jit(
270-
self.train_step,
276+
step_fn,
271277
donate_argnums=0,
272278
out_shardings=out_shardings,
273279
)
@@ -296,8 +302,14 @@ def make_test_function(self, force=False):
296302
metrics_shardings,
297303
)
298304
out_shardings = (None, state_shardings)
305+
if is_nnx_enabled():
306+
step_fn = lambda state, data: type(self).test_step(
307+
self, state, data
308+
)
309+
else:
310+
step_fn = self.test_step
299311
test_step = jit(
300-
self.test_step,
312+
step_fn,
301313
donate_argnums=0,
302314
out_shardings=out_shardings,
303315
)

0 commit comments

Comments
 (0)