@@ -116,3 +116,31 @@ def test_remat_basic_call(self):
116116 batch_size = batch_size ,
117117 verbose = 0 ,
118118 )
119+
120+ def test_remat_with_kwargs (self ):
121+ if backend .backend () in ("openvino" , "numpy" ):
122+ self .skipTest (
123+ "remat is not supported in openvino and numpy backends."
124+ )
125+
126+ # Define a function that uses keyword arguments
127+ def fn_with_kwargs (x , scale = 1.0 , offset = 0.0 ):
128+ return x * scale + offset
129+
130+ x = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float32 )
131+
132+ # Test with keyword arguments
133+ remat_fn = backend .core .remat (fn_with_kwargs )
134+ result_with_kwargs = remat_fn (x , scale = 2.0 , offset = 1.0 )
135+ expected = fn_with_kwargs (x , scale = 2.0 , offset = 1.0 )
136+ self .assertAllClose (result_with_kwargs , expected )
137+
138+ # Test with default keyword arguments
139+ result_with_defaults = remat_fn (x )
140+ expected_defaults = fn_with_kwargs (x )
141+ self .assertAllClose (result_with_defaults , expected_defaults )
142+
143+ # Test with partial keyword arguments
144+ result_partial = remat_fn (x , scale = 3.0 )
145+ expected_partial = fn_with_kwargs (x , scale = 3.0 )
146+ self .assertAllClose (result_partial , expected_partial )
0 commit comments