Skip to content

Commit c1188e7

Browse files
Add __getitem__ to backwards compatible shims.
PiperOrigin-RevId: 846196686
1 parent b9b2d4f commit c1188e7

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

jax/_src/interpreters/batching.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,10 @@ def __setitem__(self, prim, batcher):
660660
def wrapped(axis_data, vals, dims, **params):
661661
return batcher(axis_data.size, axis_data.name, None, vals, dims, **params)
662662
fancy_primitive_batchers[prim] = wrapped
663+
664+
def __getitem__(self, prim):
665+
return fancy_primitive_batchers[prim]
666+
663667
axis_primitive_batchers = AxisPrimitiveBatchersProxy()
664668

665669
# backwards compat shim. TODO: delete
@@ -675,6 +679,10 @@ def wrapped(axis_data, vals, dims, **params):
675679

676680
def __delitem__(self, prim):
677681
del fancy_primitive_batchers[prim]
682+
683+
def __getitem__(self, prim):
684+
return fancy_primitive_batchers[prim]
685+
678686
primitive_batchers = PrimitiveBatchersProxy()
679687

680688

0 commit comments

Comments
 (0)