
Commit b42c930

justinjfu authored and Google-ML-Automation committed
Fix get vmap
PiperOrigin-RevId: 833888270
1 parent 2611fd2 commit b42c930
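
Context for the diff below: `get_p` is the state primitive bound when a ref is read by indexing (e.g. `x_ref[idx]`), and `_get_vmap` is its batching rule under `jax.vmap`. A minimal sketch of the operation being batched, assuming `jax.new_ref` as used in the new test:

import jax
import jax.numpy as jnp

def read(ref):
  # Indexing a mutable ref lowers to get_p with an NDIndexer.
  return ref[0, :, 1]

x = jnp.ones((4, 9, 10, 11))
# vmap over the ref's leading axis; _get_vmap decides where the
# batch dimension lands in the output (shape (4, 10) here).
out = jax.vmap(read, in_axes=0)(jax.new_ref(x))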

File tree: 2 files changed, +52 −24 lines

  jax/_src/state/primitives.py
  tests/state_test.py

jax/_src/state/primitives.py

Lines changed: 28 additions & 24 deletions
@@ -835,6 +835,7 @@ def _batch_indexer(
           np.dtype('int32'), new_integer_indexer_shape, 0)
     else:
       batch_idx = indexing.Slice(0, axis_size)  # type: ignore
+      new_integer_indexer_shape = ()
     new_indices.insert(ref_dim, batch_idx)
   return indexing.NDIndexer(
       tuple(new_indices), ref_shape, new_integer_indexer_shape, validate=True
@@ -862,38 +863,41 @@ def _get_vmap(batched_args, batched_dims, *, tree):
   flat_indexers, tree = tree_util.tree_flatten(new_indexers)
 
   is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0])
-  int_indexers_contiguous = bool(
-      np.all(np.diff(np.where(is_int_indexing)[0]) == 1)
-  )
+  # Note: _batch_indexer will add a slice for the batch dim if the int_indexer
+  # shape is empty, else it will use advanced/int indexing.
+  will_add_int_batcher = bool(indexers[0].int_indexer_shape)
+
   is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0])
   new_int_indexers_contiguous = bool(
       np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1)
   )
 
   out = get_p.bind(ref, *flat_indexers, tree=tree)
-  if not int_indexers_contiguous:  # will always be moved to the front
-    out_bdim = 0
-  else:  # originally not going to be moved to the front
-    if new_int_indexers_contiguous:  # now not going to be moved to the front
-      try:
-        out_bdim = is_new_int_indexing.index(True)
-      except ValueError:
-        out_bdim = 0
-    else:  # now going to be moved to the front
-      original_pos = is_int_indexing.index(True)
-      array_indexer_shape = new_indexers[0].int_indexer_shape
-      array_indexer_len = len(array_indexer_shape)
-
-      transpose_order = list(range(len(out.shape)))
-      transpose_order = (
-          transpose_order[0],
-          *transpose_order[array_indexer_len:array_indexer_len+original_pos],
-          *transpose_order[1:array_indexer_len],
-          *transpose_order[array_indexer_len+original_pos:],
-      )
+  if will_add_int_batcher and not new_int_indexers_contiguous:  # now going to be moved to the front
+    original_pos = is_int_indexing.index(True)
+    array_indexer_shape = new_indexers[0].int_indexer_shape
+    array_indexer_len = len(array_indexer_shape)
 
-      out = lax.transpose(out, transpose_order)
+    transpose_order = list(range(len(out.shape)))
+    transpose_order = (
+        transpose_order[0],
+        *transpose_order[array_indexer_len:array_indexer_len+original_pos],
+        *transpose_order[1:array_indexer_len],
+        *transpose_order[array_indexer_len+original_pos:],
+    )
+    out = lax.transpose(out, transpose_order)
+    out_bdim = 0
+  else:
+    if ref_dim is not batching.not_mapped:
+      # We only trigger this case when the int_indexer shape is empty,
+      # so we don't need to account for int_indexer_shape.
+      int_indexers_before_ref_dim = int(np.sum(is_new_int_indexing[:ref_dim]))
+      out_bdim = ref_dim - int_indexers_before_ref_dim
+    else:
       out_bdim = 0
+      if any(is_int_indexing):
+        original_pos = is_int_indexing.index(True)
+        out_bdim = original_pos
   return out, out_bdim
 batching.primitive_batchers[get_p] = _get_vmap
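
The contiguity checks above mirror NumPy/XLA advanced-indexing semantics: when integer (array) indices are separated by a slice, their broadcast dimensions move to the front of the result; when they are contiguous, they stay in place. A quick NumPy illustration of that rule (illustration only, not code from this commit):

import numpy as np

x = np.ones((9, 10, 11, 12))

# Contiguous array indices: the broadcast index dimension stays in place.
print(x[:, np.array([0, 1]), np.array([0, 1])].shape)  # (9, 2, 12)

# Non-contiguous (a slice separates the array indices): the broadcast
# index dimension moves to the front of the result.
print(x[np.array([0, 1]), :, np.array([0, 1])].shape)  # (2, 10, 12)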

tests/state_test.py

Lines changed: 24 additions & 0 deletions
@@ -136,6 +136,30 @@ def f(x_ref):
     self.assertEqual(out_aval.shape, out_shape)
     self.assertEqual(out_aval.dtype, out_dtype)
 
+  @parameterized.parameters(
+      ((4, 5), 0, (0,)),
+      ((4, 5), 1, (0,)),
+      ((9, 10, 11, 12), 0, (slice(None), 0, 1)),  # Contiguous int indexing
+      ((9, 10, 11, 12), 0, (0, slice(None), 1)),  # Non-contiguous int indexing
+      ((9, 10, 11, 12), 1, (slice(None), 0, 1)),  # Contiguous after batch
+      ((9, 10, 11, 12), 2, (slice(None), 0, 1)),  # Non-contiguous after batch
+      ((9, 10, 11, 12), 3, (slice(None), slice(None), 0)),
+      # Shaped int indexer, contiguous after batch
+      ((9, 10, 11, 12), 3,
+       (slice(None), slice(None), np.array([[0, 1]]))),
+      # Shaped int indexer, non-contiguous after batch
+      ((9, 10, 11, 12), 2,
+       (np.array([[0, 1]]), slice(None), np.array([[0, 1]]))),
+  )
+  def test_vmap_of_get_regression(self, shape, in_axes, indexer):
+    # Regression test for https://github.com/jax-ml/jax/issues/33309
+    def f(x):
+      return x[indexer]
+    x = jnp.ones(shape)
+    result = jax.vmap(f, in_axes=in_axes)(jax.new_ref(x))
+    expected = jax.vmap(f, in_axes=in_axes)(x)
+    self.assertArraysEqual(result, expected)
+
   def test_swap_abstract_eval_must_take_in_refs(self):
     ref_aval = core.ShapedArray((), jnp.float32)
     val_aval = core.ShapedArray((), jnp.float32)
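
A standalone sketch mirroring the "Non-contiguous int indexing" case above; it assumes a JAX build that exposes `jax.new_ref` (as the test does) and checks that reading through a batched ref agrees with reading the plain array:

import jax
import jax.numpy as jnp
import numpy as np

def f(x):
  return x[0, :, 1]  # ints at positions 0 and 2, split by a slice

x = jnp.ones((9, 10, 11, 12))
via_ref = jax.vmap(f, in_axes=0)(jax.new_ref(x))
via_array = jax.vmap(f, in_axes=0)(x)
np.testing.assert_array_equal(via_ref, via_array)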
