File tree Expand file tree Collapse file tree 1 file changed +14
-2
lines changed
Expand file tree Collapse file tree 1 file changed +14
-2
lines changed Original file line number Diff line number Diff line change @@ -266,8 +266,14 @@ def make_train_function(self, force=False):
266266 if distribution_lib .distribution () is not None :
267267 state_shardings = self ._get_state_sharding_spec ()
268268 out_shardings = (None , state_shardings )
269+ if is_nnx_enabled ():
270+ step_fn = lambda state , data : type (self ).train_step (
271+ self , state , data
272+ )
273+ else :
274+ step_fn = self .train_step
269275 train_step = jit (
270- self . train_step ,
276+ step_fn ,
271277 donate_argnums = 0 ,
272278 out_shardings = out_shardings ,
273279 )
@@ -296,8 +302,14 @@ def make_test_function(self, force=False):
296302 metrics_shardings ,
297303 )
298304 out_shardings = (None , state_shardings )
305+ if is_nnx_enabled ():
306+ step_fn = lambda state , data : type (self ).test_step (
307+ self , state , data
308+ )
309+ else :
310+ step_fn = self .test_step
299311 test_step = jit (
300- self . test_step ,
312+ step_fn ,
301313 donate_argnums = 0 ,
302314 out_shardings = out_shardings ,
303315 )
You can’t perform that action at this time.
0 commit comments