Skip to content

Commit 2040612

Browse files
Propagate effects in the abstract eval rule for custom_vmap_p.
PiperOrigin-RevId: 845749688
1 parent 7f7b35e commit 2040612

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

jax/_src/custom_batching.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,8 @@ def custom_vmap_batching(args_flat, dims, *, call, rule, in_tree, out_tree):
260260

261261

262262
def custom_vmap_abstract_eval(*in_avals, call, **_):
263-
return call.out_avals
263+
del in_avals
264+
return call.out_avals, call.effects
264265

265266

266267
def custom_vmap_jvp(primals, tangents, *,
@@ -347,7 +348,7 @@ def to_vmap_over_extra_batched_dims(primals, tangents):
347348
custom_vmap_p = core.Primitive('custom_vmap_call')
348349
custom_vmap_p.multiple_results = True
349350
custom_vmap_p.def_impl(custom_vmap_impl)
350-
custom_vmap_p.def_abstract_eval(custom_vmap_abstract_eval)
351+
custom_vmap_p.def_effectful_abstract_eval(custom_vmap_abstract_eval)
351352
batching.primitive_batchers[custom_vmap_p] = custom_vmap_batching
352353
ad.primitive_jvps[custom_vmap_p] = custom_vmap_jvp
353354
pxla.register_initial_style_primitive(custom_vmap_p)

0 commit comments

Comments
 (0)