@@ -835,6 +835,7 @@ def _batch_indexer(
835835 np .dtype ('int32' ), new_integer_indexer_shape , 0 )
836836 else :
837837 batch_idx = indexing .Slice (0 , axis_size ) # type: ignore
838+ new_integer_indexer_shape = ()
838839 new_indices .insert (ref_dim , batch_idx )
839840 return indexing .NDIndexer (
840841 tuple (new_indices ), ref_shape , new_integer_indexer_shape , validate = True
@@ -865,35 +866,52 @@ def _get_vmap(batched_args, batched_dims, *, tree):
865866 int_indexers_contiguous = bool (
866867 np .all (np .diff (np .where (is_int_indexing )[0 ]) == 1 )
867868 )
869+ # Note: _batch_indexer will add a slice for the batch dim if the int_indexer
870+ # shape is empty, else it will use advanced/int indexing.
871+ will_add_int_batcher = bool (indexers [0 ].int_indexer_shape )
872+
868873 is_new_int_indexing , _ , _ = indexing .unpack_ndindexer (new_indexers [0 ])
869874 new_int_indexers_contiguous = bool (
870875 np .all (np .diff (np .where (is_new_int_indexing )[0 ]) == 1 )
871876 )
872877
873878 out = get_p .bind (ref , * flat_indexers , tree = tree )
874- if not int_indexers_contiguous : # will always be moved to the front
875- out_bdim = 0
876- else : # originally not going to be moved to the front
877- if new_int_indexers_contiguous : # now not going to be moved to the front
878- try :
879- out_bdim = is_new_int_indexing .index (True )
880- except ValueError :
881- out_bdim = 0
882- else : # now going to be moved to the front
883- original_pos = is_int_indexing .index (True )
884- array_indexer_shape = new_indexers [0 ].int_indexer_shape
885- array_indexer_len = len (array_indexer_shape )
886-
887- transpose_order = list (range (len (out .shape )))
888- transpose_order = (
889- transpose_order [0 ],
890- * transpose_order [array_indexer_len :array_indexer_len + original_pos ],
891- * transpose_order [1 :array_indexer_len ],
892- * transpose_order [array_indexer_len + original_pos :],
893- )
879+ should_transpose = (int_indexers_contiguous and
880+ not new_int_indexers_contiguous )
881+ if will_add_int_batcher and should_transpose :
882+ original_pos = is_int_indexing .index (True )
883+ array_indexer_shape = new_indexers [0 ].int_indexer_shape
884+ array_indexer_len = len (array_indexer_shape )
894885
895- out = lax .transpose (out , transpose_order )
886+ transpose_order = list (range (len (out .shape )))
887+ transpose_order = (
888+ transpose_order [0 ],
889+ * transpose_order [array_indexer_len :array_indexer_len + original_pos ],
890+ * transpose_order [1 :array_indexer_len ],
891+ * transpose_order [array_indexer_len + original_pos :],
892+ )
893+ out = lax .transpose (out , transpose_order )
894+ out_bdim = 0
895+ else :
896+ if ref_dim is not batching .not_mapped :
897+ if will_add_int_batcher :
898+ if not int_indexers_contiguous :
899+ # In this case the indexer is always moved to the front.
900+ out_bdim = 0
901+ else :
902+ # In this case the indexer is not moved to the front.
903+ out_bdim = is_new_int_indexing .index (True )
904+ else :
905+ # We only trigger this case when the int_indexer shape is empty,
906+ # so we don't need to account for int_indexer_shape.
907+ int_indexers_before_ref_dim = int (np .sum (is_new_int_indexing [:ref_dim ]))
908+ out_bdim = ref_dim - int_indexers_before_ref_dim
909+ else :
896910 out_bdim = 0
911+ if any (is_int_indexing ):
912+ # The batch dim is the indexer's batch dim.
913+ original_pos = is_int_indexing .index (True )
914+ out_bdim = original_pos
897915 return out , out_bdim
898916batching .primitive_batchers [get_p ] = _get_vmap
899917
0 commit comments