Skip to content

Commit 3459817

Browse files
committed
Fix torch keyword arguments in remat.
#21861
1 parent f2c00fe commit 3459817

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

keras/src/backend/common/remat_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,31 @@ def test_remat_basic_call(self):
116116
batch_size=batch_size,
117117
verbose=0,
118118
)
119+
120+
def test_remat_with_kwargs(self):
121+
if backend.backend() in ("openvino", "numpy"):
122+
self.skipTest(
123+
"remat is not supported in openvino and numpy backends."
124+
)
125+
126+
# Define a function that uses keyword arguments
127+
def fn_with_kwargs(x, scale=1.0, offset=0.0):
128+
return x * scale + offset
129+
130+
x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
131+
132+
# Test with keyword arguments
133+
remat_fn = backend.core.remat(fn_with_kwargs)
134+
result_with_kwargs = remat_fn(x, scale=2.0, offset=1.0)
135+
expected = fn_with_kwargs(x, scale=2.0, offset=1.0)
136+
self.assertAllClose(result_with_kwargs, expected)
137+
138+
# Test with default keyword arguments
139+
result_with_defaults = remat_fn(x)
140+
expected_defaults = fn_with_kwargs(x)
141+
self.assertAllClose(result_with_defaults, expected_defaults)
142+
143+
# Test with partial keyword arguments
144+
result_partial = remat_fn(x, scale=3.0)
145+
expected_partial = fn_with_kwargs(x, scale=3.0)
146+
self.assertAllClose(result_partial, expected_partial)

keras/src/backend/torch/core.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,9 @@ def remat(f):
673673
"""
674674

675675
def wrapped(*args, **kwargs):
676-
return torch.utils.checkpoint.checkpoint(f, *args, use_reentrant=False)
676+
return torch.utils.checkpoint.checkpoint(
677+
f, *args, use_reentrant=False, **kwargs
678+
)
677679

678680
return wrapped
679681

0 commit comments

Comments
 (0)