Skip to content

Commit 7065e4f

Browse files
Merge pull request #33217 from gspschmid:gschmid/remat-ob-fix-consts
PiperOrigin-RevId: 831470125
2 parents 3075641 + 6a0415a commit 7065e4f

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

jax/_src/ad_checkpoint.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,8 @@ def fun_remat(*args, **kwargs):
363363
in_avals = [core.shaped_abstractify(x) for x in args_flat]
364364
jaxpr, consts, out_tree = _trace_to_jaxpr(fun_, in_tree, tuple(in_avals), debug)
365365
if isinstance(prevent_cse, tuple):
366-
cse = (*broadcast_prefix(prevent_cse, (args, kwargs) if kwargs else args),)
366+
cse_args = (tuple(args), kwargs) if kwargs else tuple(args)
367+
cse = (False,) * len(consts) + tuple(broadcast_prefix(prevent_cse, cse_args))
367368
else:
368369
cse = prevent_cse
369370
out_flat = remat_p.bind(

tests/api_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7117,7 +7117,9 @@ def make_weight(i):
71177117
def test_remat_partial_cse_prevention(self):
71187118
@partial(jax.remat, prevent_cse=(False, True))
71197119
def layer(W, x):
7120-
return x @ W
7120+
res = x @ W
7121+
res += jnp.array([1.0, 2.0, 3.0]) # ensure the jaxpr also contains a const
7122+
return res
71217123

71227124
def net(Ws, x):
71237125
for W in Ws:

0 commit comments

Comments
 (0)