@@ -835,6 +835,7 @@ def _batch_indexer(
           np.dtype('int32'), new_integer_indexer_shape, 0)
     else:
       batch_idx = indexing.Slice(0, axis_size)  # type: ignore
+      new_integer_indexer_shape = ()
     new_indices.insert(ref_dim, batch_idx)
   return indexing.NDIndexer(
       tuple(new_indices), ref_shape, new_integer_indexer_shape, validate=True
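
Reviewer note on the first hunk (an illustrative sketch, not part of the commit; names and shapes below are made up): when only the ref is batched, _batch_indexer addresses the batch dimension with a plain Slice, which contributes no integer-indexer axes, hence the added new_integer_indexer_shape = (). When the indices themselves are batched, the batch dimension is instead addressed with a broadcasted iota, which participates in advanced indexing and prepends axis_size to the integer-indexer shape. The same distinction in plain NumPy:

    import numpy as np

    axis_size = 4
    ref = np.arange(4 * 8 * 16).reshape(4, 8, 16)  # hypothetical ref, batch axis 0

    # Ref batched, indices not: the batch dim is a plain slice, so no
    # integer indexers are involved and the int-indexer shape is ().
    out = ref[0:axis_size, 2:5, :]
    assert out.shape == (4, 3, 16)

    # Indices batched: the batch dim becomes an iota that broadcasts with the
    # other integer indexers, prepending axis_size to the int-indexer shape.
    idx = np.array([1, 3, 5])                 # hypothetical per-example indices
    iota = np.arange(axis_size)[:, None]      # shape (axis_size, 1)
    out = ref[iota, np.broadcast_to(idx, (axis_size, 3)), :]
    assert out.shape == (4, 3, 16)            # int-indexer shape: (4, 3)
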
@@ -862,38 +863,41 @@ def _get_vmap(batched_args, batched_dims, *, tree):
   flat_indexers, tree = tree_util.tree_flatten(new_indexers)
 
   is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0])
-  int_indexers_contiguous = bool(
-      np.all(np.diff(np.where(is_int_indexing)[0]) == 1)
-  )
+  # Note: _batch_indexer will add a slice for the batch dim if the int_indexer
+  # shape is empty, else it will use advanced/int indexing.
+  will_add_int_batcher = bool(indexers[0].int_indexer_shape)
+
   is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0])
   new_int_indexers_contiguous = bool(
       np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1)
   )
 
   out = get_p.bind(ref, *flat_indexers, tree=tree)
-  if not int_indexers_contiguous:  # will always be moved to the front
-    out_bdim = 0
-  else:  # originally not going to be moved to the front
-    if new_int_indexers_contiguous:  # now not going to be moved to the front
-      try:
-        out_bdim = is_new_int_indexing.index(True)
-      except ValueError:
-        out_bdim = 0
-    else:  # now going to be moved to the front
-      original_pos = is_int_indexing.index(True)
-      array_indexer_shape = new_indexers[0].int_indexer_shape
-      array_indexer_len = len(array_indexer_shape)
-
-      transpose_order = list(range(len(out.shape)))
-      transpose_order = (
-          transpose_order[0],
-          *transpose_order[array_indexer_len:array_indexer_len + original_pos],
-          *transpose_order[1:array_indexer_len],
-          *transpose_order[array_indexer_len + original_pos:],
-      )
+  if will_add_int_batcher and not new_int_indexers_contiguous:  # now going to be moved to the front
+    original_pos = is_int_indexing.index(True)
+    array_indexer_shape = new_indexers[0].int_indexer_shape
+    array_indexer_len = len(array_indexer_shape)
 
-      out = lax.transpose(out, transpose_order)
+    transpose_order = list(range(len(out.shape)))
+    transpose_order = (
+        transpose_order[0],
+        *transpose_order[array_indexer_len:array_indexer_len + original_pos],
+        *transpose_order[1:array_indexer_len],
+        *transpose_order[array_indexer_len + original_pos:],
+    )
+    out = lax.transpose(out, transpose_order)
+    out_bdim = 0
+  else:
+    if ref_dim is not batching.not_mapped:
+      # We only trigger this case when the int_indexer shape is empty,
+      # so we don't need to account for int_indexer_shape.
+      int_indexers_before_ref_dim = int(np.sum(is_new_int_indexing[:ref_dim]))
+      out_bdim = ref_dim - int_indexers_before_ref_dim
+    else:
       out_bdim = 0
+      if any(is_int_indexing):
+        original_pos = is_int_indexing.index(True)
+        out_bdim = original_pos
   return out, out_bdim
 batching.primitive_batchers[get_p] = _get_vmap
 
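Reviewer note on the second hunk (again an illustration, not part of the commit): NumPy-style advanced indexing keeps the broadcast integer-indexer axes in place when the integer indexers are contiguous, but moves them to the front of the result when they are not. After _batch_indexer prepends a batched iota, the batch axis becomes the leading axis of that indexer block, so when the block is moved to the front the transpose above only has to restore the order of the remaining axes and can report out_bdim = 0. A minimal demonstration of the rule:

    import numpy as np

    x = np.zeros((5, 6, 7, 8))
    i = np.zeros((2, 3), dtype=int)

    # Contiguous integer indexers (axes 1 and 2): the broadcast indexer
    # axes (2, 3) stay where the indexers were.
    assert x[:, i, i, :].shape == (5, 2, 3, 8)

    # Non-contiguous integer indexers (axes 1 and 3): the indexer axes
    # are moved to the front of the result.
    assert x[:, i, :, i].shape == (2, 3, 5, 7)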
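For the slice-batched branch (ref mapped, int-indexer shape empty), one scenario consistent with the ref_dim - int_indexers_before_ref_dim arithmetic is scalar integer indices collapsing axes in front of the inserted batch slice. This sketch is an assumption about the intended case, not taken from the commit:

    import numpy as np

    axis_size = 4
    ref_dim = 2
    ref = np.zeros((6, 7, axis_size))   # hypothetical ref, batched at ref_dim

    # Unbatched indices [3, 0:7] plus the inserted batch slice: the scalar 3
    # collapses axis 0, shifting the batch axis one slot to the left.
    out = ref[3, 0:7, 0:axis_size]
    assert out.shape == (7, axis_size)

    is_new_int_indexing = [True, False, False]   # scalar int, slice, batch slice
    out_bdim = ref_dim - int(np.sum(is_new_int_indexing[:ref_dim]))
    assert out_bdim == 1                         # batch axis is axis 1 of out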