@@ -575,6 +575,8 @@ def compute_shape(
         input_size = collections.deque(tensor.size())
         output_size = []
         env = CompileEnvironment.current()
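+        # Tensor indexers may broadcast together (NumPy-style advanced
+        # indexing); that decision drives how output dims are built below.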
+        tensor_indexers = [k for k in index if isinstance(k, torch.Tensor)]
+        should_broadcast = env.should_broadcast_tensor_indexers(tensor_indexers)
         k_index = 0
         for k in index:
             if k is None:
@@ -617,11 +619,14 @@ def compute_shape(
                 else:
                     output_size.append(1)
                 k_index += 1
-            elif isinstance(k, torch.Tensor) and (
-                k.ndim == 1 or (len(index) == 1 and tensor.ndim == 1)
-            ):
+            elif isinstance(k, torch.Tensor):
                 input_size.popleft()
-                output_size.extend(k.size())
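+                # In the broadcast case all tensor indexers collapse into one
+                # broadcast shape, emitted once at the first tensor's position.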
+                if not should_broadcast:
+                    output_size.extend(env.tensor_indexer_dims(k))
+                elif k is tensor_indexers[0]:
+                    output_size.extend(
+                        env.tensor_indexer_broadcast_shape(tensor_indexers)
+                    )
                 k_index += 1
             else:
                 raise exc.InvalidIndexingType(k)
@@ -667,13 +672,87 @@ def create(
         output_size = SubscriptIndexing.compute_shape(fake_value, index, state)
         env = CompileEnvironment.current()
         dtype = env.triton_index_type()
+        tensor_indexers = [k for k in index if isinstance(k, torch.Tensor)]
+        should_broadcast = env.should_broadcast_tensor_indexers(tensor_indexers)
+        broadcast_dims = 0
+        if should_broadcast:
+            broadcast_dims = len(env.tensor_indexer_broadcast_shape(tensor_indexers))
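+        # "Cartesian" pattern: each tensor indexer supplies exactly one
+        # non-trivial dim, so indexer i maps 1:1 onto output dim i.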
+        is_cartesian = (
+            broadcast_dims >= 2
+            and len(tensor_indexers) == broadcast_dims
+            and all(
+                t.ndim == 1
+                or sum(1 for d in t.size() if env.size_hint(d) != 1) <= 1
+                for t in tensor_indexers
+            )
+        )
         if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value):
             raise exc.IndexOffsetOutOfRangeForInt32(env.index_dtype)

         def _is_size_one(size: int | torch.SymInt) -> bool:
             return env.known_equal(size, 1)

         k_index = 0
+
+        def handle_broadcast_tensor(
+            position: int,
+            index_elem: torch.Tensor,
+            index_var: str,
+            cur_output_idx: int,
+        ) -> tuple[str, dict[str, None]]:
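+            # Build the index expression and any mask entries for one tensor
+            # indexer in a broadcast group; only the group's first tensor
+            # contributes block masks, and the caller advances output_idx
+            # once per group.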
+            assert broadcast_dims > 0
+            tensor_idx = next(
+                i for i, t in enumerate(tensor_indexers) if t is index_elem
+            )
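+            # For tensors after the first, cur_output_idx has already been
+            # advanced past the broadcast dims, so rewind to the group's start.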
+            first_tensor_out_idx = (
+                cur_output_idx if tensor_idx == 0 else cur_output_idx - broadcast_dims
+            )
+            non_trivial_output_positions: list[int] = []
+            if is_cartesian:
+                pos = first_tensor_out_idx + tensor_idx
+                single_output_dim = True
+            else:
+                # Find the position(s) where this tensor contributes non-trivial dims
+                offset = max(0, broadcast_dims - index_elem.ndim)
+                non_trivial_output_positions = [
+                    first_tensor_out_idx + offset + i
+                    for i in range(index_elem.ndim)
+                    if env.size_hint(index_elem.size(i)) != 1
+                ]
+                pos = non_trivial_output_positions[0]
+                single_output_dim = len(non_trivial_output_positions) <= 1
+
+            new_masks: dict[str, None] = {}
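+            # A 1-D indexer is expanded into the output rank; higher-rank
+            # indexers already carry their own broadcastable shape.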
+            if single_output_dim:
+                expand = (
+                    tile_strategy.expand_str(output_size, pos)
+                    if index_elem.ndim == 1
+                    else ""
+                )
+                idx_val = f"({index_var}){expand}"
+            else:
+                # Multi-dim tensor with multiple non-trivial dims
+                idx_val = f"({index_var})"
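+            # Emit block masks once per group (only for the first tensor) so
+            # the same mask is not duplicated for every indexer.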
+            if tensor_idx == 0:
+                for p in non_trivial_output_positions:
+                    if (
+                        p < len(output_size)
+                        and (bid := env.get_block_id(output_size[p]))
+                        and (mv := state.codegen.mask_var(bid))
+                        and not _is_size_one(fake_value.size(len(index_values)))
+                    ):
+                        new_masks.setdefault(
+                            f"({mv}){tile_strategy.expand_str(output_size, p)}"
+                        )
+            # Mask out elements past the original length of a padded hl.arange
+            if (
+                orig_len := _get_padded_iota_original_length(state, position)
+            ) is not None:
+                new_masks.setdefault(
+                    f"(({index_var} < {orig_len}){tile_strategy.expand_str(output_size, first_tensor_out_idx + tensor_idx)})"
+                )
+            return idx_val, new_masks
+
         for n, k in enumerate(index):
             if k is None:
                 output_idx += 1
@@ -752,40 +831,42 @@ def _is_size_one(size: int | torch.SymInt) -> bool:
                 index_values.append(f"tl.zeros([1], {dtype}){expand}")
                 output_idx += 1
                 k_index += 1
-            elif isinstance(k, torch.Tensor) and k.ndim == 1:
-                expand = tile_strategy.expand_str(output_size, output_idx)
+            elif isinstance(k, torch.Tensor):
                 ast_index = state.ast_args[1]
                 assert isinstance(ast_index, (list, tuple))
-                assert len(ast_index) == len(index)
                 index_var = state.codegen.lift(ast_index[n], prefix="index").id
+
+                # Use broadcast handling for multiple tensors, or a single tensor with ndim > 1
+                if should_broadcast:
+                    idx_val, new_masks = handle_broadcast_tensor(
+                        n, k, index_var, output_idx
+                    )
+                    index_values.append(idx_val)
+                    mask_values.update(new_masks)
+                    if k is tensor_indexers[0]:
+                        output_idx += broadcast_dims
+                    k_index += 1
+                    continue
+
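+                # Non-broadcast path: a lone 1-D tensor indexer keeps its own
+                # dims; expand only when the output rank exceeds the tensor's.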
+                expand = (
+                    tile_strategy.expand_str(output_size, output_idx)
+                    if k.ndim < len(output_size)
+                    else ""
+                )
                 index_values.append(f"({index_var}){expand}")
-                if (block_idx := env.get_block_id(output_size[output_idx])) is not None:
-                    if mask := state.codegen.mask_var(block_idx):
-                        mask_values.setdefault(f"({mask}){expand}")
-                # Check if this index comes from a padded hl.arange and generate mask
-                if (
-                    original_length := _get_padded_iota_original_length(state, n)
-                ) is not None:
-                    mask_values.setdefault(f"({index_var} < {original_length}){expand}")
-                output_idx += 1
-                k_index += 1
-            elif (
-                isinstance(k, torch.Tensor) and len(index) == 1 and fake_value.ndim == 1
-            ):
-                # TODO(jansel): combine this case with the above
-                ast_index = state.ast_args[1]
-                assert isinstance(ast_index, (list, tuple))
-                assert len(ast_index) == 1
-                index_var = state.codegen.lift(ast_index[0], prefix="index").id
-                index_values.append(index_var)
-                output_idx += k.ndim
-                for n, s in enumerate(output_size):
-                    if (block_idx := env.get_block_id(s)) is not None and (
-                        mask := state.codegen.mask_var(block_idx)
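+                # Emit a block mask only when the output dim at this position
+                # is tiled and the indexed dim is not statically size 1.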
+                mask_block_id = (
+                    env.get_block_id(output_size[output_idx])
+                    if output_idx < len(output_size)
+                    else None
+                )
+                if mask_block_id is not None:
+                    mask_var = state.codegen.mask_var(mask_block_id)
+                    if mask_var and not _is_size_one(
+                        fake_value.size(len(index_values) - 1)
                     ):
-                        mask_values.setdefault(
-                            f"({mask}){tile_strategy.expand_str(output_size, n)}"
-                        )
+                        mask_values.setdefault(f"({mask_var}){expand}")
+
+                output_idx += k.ndim
                 k_index += 1
             else:
                 raise exc.InvalidIndexingType(type(k))
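
A minimal sketch of the indexing-shape rule the change above implements, in
plain PyTorch (assuming Helion's tensor_indexer_broadcast_shape mirrors
NumPy-style advanced-indexing semantics):

    import torch

    a = torch.arange(12).reshape(3, 4)

    # Elementwise: two 1-D indexers broadcast into a single output dim.
    i = torch.tensor([0, 2])
    j = torch.tensor([1, 3])
    assert a[i, j].shape == (2,)  # picks a[0, 1] and a[2, 3]

    # "Cartesian" pattern: each indexer has exactly one non-trivial dim,
    # so each owns its own output dim (the is_cartesian case above).
    rows = torch.tensor([[0], [2]])  # shape (2, 1)
    cols = torch.tensor([[1, 3]])    # shape (1, 2)
    assert a[rows, cols].shape == (2, 2)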