
Commit 1f880ea

Add 2d and 3d indirect indexing support (#593)
1 parent 9d0b8bd commit 1f880ea

File tree: 9 files changed, +871 -54 lines changed
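As a rough sketch of what this commit enables (a hypothetical kernel written against the usual helion API of @helion.kernel and hl.tile; it is not taken from the commit's tests, and all names are illustrative), a tensor can now be gathered with a 2D index tensor inside a tiled loop, with the result keeping the 2D tile shape:

import torch

import helion
import helion.language as hl


# Hypothetical example: src is 1D, idx is a 2D integer tensor, out[i, j] = src[idx[i, j]]
@helion.kernel
def gather_2d(src: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    m, n = idx.size()
    out = torch.empty([m, n], dtype=src.dtype, device=src.device)
    for tile_m, tile_n in hl.tile([m, n]):
        # 2D indirect indexing: the gathered block keeps the (tile_m, tile_n) shape
        out[tile_m, tile_n] = src[idx[tile_m, tile_n]]
    return out

A 3D indexer follows the same pattern with a third tile dimension.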

helion/_compiler/compile_environment.py

Lines changed: 103 additions & 11 deletions
@@ -18,6 +18,8 @@
 from torch._inductor.utils import triton_type
 from torch._subclasses import FakeTensorMode
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
+from torch.utils._sympy.symbol import SymT
+from torch.utils._sympy.symbol import symbol_is_type
 
 from .. import exc
 from ..language.constexpr import ConstExpr
@@ -167,16 +169,23 @@ def allocate_block_size(
         reduction: bool = False,
         source: BlockSizeSource,
         hint: int = 64,
+        reuse_var: torch.SymInt | None = None,
     ) -> int:
         idx = len(self.block_sizes)
+        # Use the provided var or create a new one
+        var = (
+            reuse_var
+            if reuse_var is not None
+            else self.create_block_var(
+                f"block_size_{idx}" if not reduction else f"rdim_{idx}",
+                hint=hint,
+            )
+        )
         self.block_sizes.append(
             info := BlockSizeInfo(
                 block_id=idx,
                 size=size,
-                var=self.create_block_var(
-                    f"block_size_{idx}" if not reduction else f"rdim_{idx}",
-                    hint=hint,
-                ),
+                var=var,
                 reduction=reduction,
                 block_size_source=source,
             )
@@ -185,37 +194,55 @@ def allocate_block_size(
         from .host_function import HostFunction
         from .host_function import SymbolOrigin
 
-        HostFunction.current().expr_to_origin[info.symbol()] = SymbolOrigin(
-            origin=BlockSizeOrigin(idx),
-        )
+        # Only register in expr_to_origin if we created a new var
+        # (otherwise the var is already registered under its original block)
+        if reuse_var is None:
+            HostFunction.current().expr_to_origin[info.symbol()] = SymbolOrigin(
+                origin=BlockSizeOrigin(idx),
+            )
         return idx
 
     def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInfo:
         # Check if this size is already a registered block size
+        existing_block: BlockSizeInfo | None = None
         if isinstance(size, torch.SymInt):
             from .host_function import HostFunction
 
             expr = size._sympy_()
             origin_info = HostFunction.current().expr_to_origin.get(expr)
             if origin_info and isinstance(origin_info.origin, BlockSizeOrigin):
                 block_idx = origin_info.origin.block_id
-                # Return the existing block size if it's a reduction dimension
-                if self.block_sizes[block_idx].reduction:
-                    return self.block_sizes[block_idx]
+                existing_block = self.block_sizes[block_idx]
+
+        def _is_unbacked_symint(x: int | torch.SymInt) -> bool:
+            if not isinstance(x, torch.SymInt):
+                return False
+            expr = x._sympy_()
+            if isinstance(expr, sympy.Symbol):
+                return symbol_is_type(expr, SymT.UNBACKED_INT)
+            return False
 
         # Check for existing reduction dimensions with the same size
         for rdim in self.block_sizes:
-            if rdim.reduction and rdim.size == size:
+            if not rdim.reduction or not isinstance(rdim.size, (int, torch.SymInt)):
+                continue
+            if _is_unbacked_symint(rdim.size) and _is_unbacked_symint(size):
+                if self.known_equal(rdim.size, size):
+                    return rdim
+            elif rdim.size == size:
                 return rdim
 
         # Allocate a new reduction dimension
+        # If size is already a block var, reuse it to maintain symbol identity
+        reuse_var = existing_block.var if existing_block is not None else None
         rdim_idx = self.allocate_block_size(
             size,
             reduction=True,
             source=ReductionLoopBlockSizeSource(
                 sum([int(bs.reduction) for bs in self.block_sizes])
             ),
             hint=next_power_of_2(self.size_hint(size)),
+            reuse_var=reuse_var,
         )
         return self.block_sizes[rdim_idx]

@@ -272,6 +299,71 @@ def cached_create_unbacked_symint(
         self._symint_cache[key] = result
         return result
 
+    def _normalize_shape_to_block_vars(
+        self, shape: list[int | torch.SymInt]
+    ) -> list[int | torch.SymInt]:
+        """Normalize shape dimensions to use canonical block size variables."""
+        return [
+            self.block_sizes[bid].var
+            if (bid := self.get_block_id(s)) is not None
+            else s
+            for s in shape
+        ]
+
+    def should_broadcast_tensor_indexers(
+        self, tensors: typing.Sequence[torch.Tensor]
+    ) -> bool:
+        """Check whether tensor indexers need broadcasting."""
+        if not tensors:
+            return False
+        # 1D tensors with block-size dims don't need broadcasting
+        if all(
+            t.ndim == 1 and self.get_block_id(t.size(0)) is not None for t in tensors
+        ):
+            return False
+        # Single 1D tensor doesn't need broadcast handling
+        return not (len(tensors) == 1 and tensors[0].ndim == 1)
+
+    def tensor_indexer_broadcast_shape(
+        self, tensors: typing.Sequence[torch.Tensor]
+    ) -> list[int | torch.SymInt]:
+        """Compute broadcast shape for tensor indexers."""
+        shapes = [list(t.size()) for t in tensors]
+        if all(len(s) == 1 for s in shapes) and len(shapes) > 1:  # Cartesian
+            # Normalize each dimension to block size variable
+            return self._normalize_shape_to_block_vars([s[0] for s in shapes])
+        max_ndim = max(len(s) for s in shapes)
+        padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
+        result = [
+            next((d for d in dims if self.size_hint(d) != 1), 1)
+            for dims in zip(*padded, strict=True)
+        ]
+        # Normalize the result to use canonical block size variables
+        return self._normalize_shape_to_block_vars(result)
+
+    def tensor_indexer_dims(
+        self, indexer_tensor: torch.Tensor
+    ) -> list[int | torch.SymInt]:
+        """Return dims contributed by a tensor indexer (non-broadcast case)."""
+        non_trivial = [d for d in indexer_tensor.size() if self.size_hint(d) != 1]
+        # Use size-based approach to find block_id
+        bid = self.get_block_id(non_trivial[0]) if non_trivial else None
+        if bid is not None:
+            return [self.block_sizes[bid].var]
+        return non_trivial or [1]  # type: ignore[return-value]
+
+    def new_index_result(
+        self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> torch.Tensor:
+        """Create tensor for indexing ops with normalized shapes.
+
+        Uses size-based approach to normalize all dimensions that correspond
+        to block sizes to their canonical variables.
+        """
+        # Normalize all dimensions to canonical block size variables
+        shape = self._normalize_shape_to_block_vars(list(output_shape))
+        return tensor.new_empty(shape)
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if obj is None:
             return None
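The shape rule implemented by tensor_indexer_broadcast_shape can be sketched standalone with plain ints standing in for SymInt and size_hint (an illustration of the rule only, not the helion implementation):

# Plain-int model of the rule (the helion version additionally normalizes to block-size vars).
def indexer_broadcast_shape(shapes: list[list[int]]) -> list[int]:
    # Several 1D indexers: each contributes its own output dimension (Cartesian case).
    if all(len(s) == 1 for s in shapes) and len(shapes) > 1:
        return [s[0] for s in shapes]
    # Otherwise: right-align, pad with 1s, and take the first non-1 size per dimension.
    max_ndim = max(len(s) for s in shapes)
    padded = [[1] * (max_ndim - len(s)) + s for s in shapes]
    return [next((d for d in dims if d != 1), 1) for dims in zip(*padded)]


assert indexer_broadcast_shape([[16], [32]]) == [16, 32]        # two 1D indexers -> 2D output
assert indexer_broadcast_shape([[16, 1], [1, 32]]) == [16, 32]  # (16,1) broadcast with (1,32)
assert indexer_broadcast_shape([[8, 4]]) == [8, 4]              # single 2D indexer keeps its shape

The Cartesian branch is what lets several 1D tile-sized indexers each contribute their own output dimension, while mixed-rank indexers fall back to right-aligned broadcasting.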

helion/_compiler/indexing_strategy.py

Lines changed: 114 additions & 33 deletions
@@ -575,6 +575,8 @@ def compute_shape(
         input_size = collections.deque(tensor.size())
         output_size = []
         env = CompileEnvironment.current()
+        tensor_indexers = [k for k in index if isinstance(k, torch.Tensor)]
+        should_broadcast = env.should_broadcast_tensor_indexers(tensor_indexers)
         k_index = 0
         for k in index:
             if k is None:
@@ -617,11 +619,14 @@ def compute_shape(
                 else:
                     output_size.append(1)
                 k_index += 1
-            elif isinstance(k, torch.Tensor) and (
-                k.ndim == 1 or (len(index) == 1 and tensor.ndim == 1)
-            ):
+            elif isinstance(k, torch.Tensor):
                 input_size.popleft()
-                output_size.extend(k.size())
+                if not should_broadcast:
+                    output_size.extend(env.tensor_indexer_dims(k))
+                elif k is tensor_indexers[0]:
+                    output_size.extend(
+                        env.tensor_indexer_broadcast_shape(tensor_indexers)
+                    )
                 k_index += 1
             else:
                 raise exc.InvalidIndexingType(k)
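For the broadcast cases above, the shapes compute_shape produces are meant to line up with eager PyTorch advanced indexing; a small eager-mode check (independent of helion, shown only as a shape reference):

import torch

a = torch.arange(100.0).reshape(10, 10)
i = torch.randint(0, 10, (4, 1))
j = torch.randint(0, 10, (1, 5))
assert a[i, j].shape == (4, 5)   # two indexers broadcast to one (4, 5) block of dims

b = torch.arange(16.0)
k2d = torch.randint(0, 16, (8, 2))
assert b[k2d].shape == (8, 2)    # a single 2D indexer contributes its own shape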
@@ -667,13 +672,87 @@ def create(
         output_size = SubscriptIndexing.compute_shape(fake_value, index, state)
         env = CompileEnvironment.current()
         dtype = env.triton_index_type()
+        tensor_indexers = [k for k in index if isinstance(k, torch.Tensor)]
+        should_broadcast = env.should_broadcast_tensor_indexers(tensor_indexers)
+        broadcast_dims = 0
+        if should_broadcast:
+            broadcast_dims = len(env.tensor_indexer_broadcast_shape(tensor_indexers))
+        is_cartesian = (
+            broadcast_dims >= 2
+            and len(tensor_indexers) == broadcast_dims
+            and all(
+                t.ndim == 1
+                or sum(1 for d in t.size() if env.size_hint(d) != 1) <= 1
+                for t in tensor_indexers
+            )
+        )
         if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value):
             raise exc.IndexOffsetOutOfRangeForInt32(env.index_dtype)
 
         def _is_size_one(size: int | torch.SymInt) -> bool:
             return env.known_equal(size, 1)
 
         k_index = 0
+
+        def handle_broadcast_tensor(
+            position: int,
+            index_elem: torch.Tensor,
+            index_var: str,
+            cur_output_idx: int,
+        ) -> tuple[str, dict[str, None]]:
+            assert broadcast_dims > 0
+            tensor_idx = next(
+                i for i, t in enumerate(tensor_indexers) if t is index_elem
+            )
+            first_tensor_out_idx = (
+                cur_output_idx if tensor_idx == 0 else cur_output_idx - broadcast_dims
+            )
+            non_trivial_output_positions: list[int] = []
+            if is_cartesian:
+                pos = first_tensor_out_idx + tensor_idx
+                single_output_dim = True
+            else:
+                # Find position(s) where this tensor contributes non-trivial dims
+                offset = max(0, broadcast_dims - index_elem.ndim)
+                non_trivial_output_positions = [
+                    first_tensor_out_idx + offset + i
+                    for i in range(index_elem.ndim)
+                    if env.size_hint(index_elem.size(i)) != 1
+                ]
+                pos = non_trivial_output_positions[0]
+                single_output_dim = len(non_trivial_output_positions) <= 1
+
+            new_masks: dict[str, None] = {}
+            if single_output_dim:
+                expand = (
+                    tile_strategy.expand_str(output_size, pos)
+                    if index_elem.ndim == 1
+                    else ""
+                )
+                idx_val = f"({index_var}){expand}"
+            else:
+                # Multi-dim tensor with multiple non-trivial dims
+                idx_val = f"({index_var})"
+                if tensor_idx == 0:
+                    for p in non_trivial_output_positions:
+                        if (
+                            p < len(output_size)
+                            and (bid := env.get_block_id(output_size[p]))
+                            and (mv := state.codegen.mask_var(bid))
+                            and not _is_size_one(fake_value.size(len(index_values)))
+                        ):
+                            new_masks.setdefault(
+                                f"({mv}){tile_strategy.expand_str(output_size, p)}"
+                            )
+            # Padded iota mask
+            if (
+                orig_len := _get_padded_iota_original_length(state, position)
+            ) is not None:
+                new_masks.setdefault(
+                    f"(({index_var} < {orig_len}){tile_strategy.expand_str(output_size, first_tensor_out_idx + tensor_idx)})"
+                )
+            return idx_val, new_masks
+
         for n, k in enumerate(index):
             if k is None:
                 output_idx += 1
@@ -752,40 +831,42 @@ def _is_size_one(size: int | torch.SymInt) -> bool:
                 index_values.append(f"tl.zeros([1], {dtype}){expand}")
                 output_idx += 1
                 k_index += 1
-            elif isinstance(k, torch.Tensor) and k.ndim == 1:
-                expand = tile_strategy.expand_str(output_size, output_idx)
+            elif isinstance(k, torch.Tensor):
                 ast_index = state.ast_args[1]
                 assert isinstance(ast_index, (list, tuple))
-                assert len(ast_index) == len(index)
                 index_var = state.codegen.lift(ast_index[n], prefix="index").id
+
+                # Use broadcast handling for: multiple tensors, or single tensor with ndim > 1
+                if should_broadcast:
+                    idx_val, new_masks = handle_broadcast_tensor(
+                        n, k, index_var, output_idx
+                    )
+                    index_values.append(idx_val)
+                    mask_values.update(new_masks)
+                    if k is tensor_indexers[0]:
+                        output_idx += broadcast_dims
+                    k_index += 1
+                    continue
+
+                expand = (
+                    tile_strategy.expand_str(output_size, output_idx)
+                    if k.ndim < len(output_size)
+                    else ""
+                )
                 index_values.append(f"({index_var}){expand}")
-                if (block_idx := env.get_block_id(output_size[output_idx])) is not None:
-                    if mask := state.codegen.mask_var(block_idx):
-                        mask_values.setdefault(f"({mask}){expand}")
-                # Check if this index comes from a padded hl.arange and generate mask
-                if (
-                    original_length := _get_padded_iota_original_length(state, n)
-                ) is not None:
-                    mask_values.setdefault(f"({index_var} < {original_length}){expand}")
-                output_idx += 1
-                k_index += 1
-            elif (
-                isinstance(k, torch.Tensor) and len(index) == 1 and fake_value.ndim == 1
-            ):
-                # TODO(jansel): combine this case with the above
-                ast_index = state.ast_args[1]
-                assert isinstance(ast_index, (list, tuple))
-                assert len(ast_index) == 1
-                index_var = state.codegen.lift(ast_index[0], prefix="index").id
-                index_values.append(index_var)
-                output_idx += k.ndim
-                for n, s in enumerate(output_size):
-                    if (block_idx := env.get_block_id(s)) is not None and (
-                        mask := state.codegen.mask_var(block_idx)
+                mask_block_id = (
+                    env.get_block_id(output_size[output_idx])
+                    if output_idx < len(output_size)
+                    else None
+                )
+                if mask_block_id is not None:
+                    mask_var = state.codegen.mask_var(mask_block_id)
+                    if mask_var and not _is_size_one(
+                        fake_value.size(len(index_values) - 1)
                     ):
-                        mask_values.setdefault(
-                            f"({mask}){tile_strategy.expand_str(output_size, n)}"
-                        )
+                        mask_values.setdefault(f"({mask_var}){expand}")
+
+                output_idx += k.ndim
                 k_index += 1
             else:
                 raise exc.InvalidIndexingType(type(k))