Describe the bug
It looks like the expression containing the `min` is getting hoisted out of the loop despite depending on the loop induction variable. Possibly the special handling for attention hoisting is coming back to bite us here?
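For illustration, here is a plain-Python analogue of why that hoisting is invalid (a hypothetical sketch, not Helion or Triton code; `begin` stands in for the chunk loop's induction variable, with illustrative sizes taken from the repro below):

```python
# Hypothetical plain-Python analogue of the bug -- not Helion/Triton code.
# t_i_last depends on the induction variable, so it must be recomputed
# inside the loop:
seqlen, chunk_size = 4096, 256  # illustrative values from the repro below
for begin in range(0, seqlen, chunk_size):
    t_i_last = min(begin + chunk_size - 1, seqlen - 1)  # per-iteration value

# Hoisting that statement above the loop would reference `begin` before the
# loop defines it -- the same shape as the "offset_4 is not defined" error
# in the traceback below.
```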
To Reproduce
On top of #1119:
$ git diff
diff --git a/examples/gdn_fwd_h.py b/examples/gdn_fwd_h.py
index 45e11f0..5512366 100644
--- a/examples/gdn_fwd_h.py
+++ b/examples/gdn_fwd_h.py
@@ -63,7 +63,7 @@ def helion_gdn_fwd_h(
     p_v = u[tile_b.begin, t_i, tile_h.begin, tile_v].to(acc_dtype)
     b_v = p_v - b_v
     m_t = t_i.index < seqlen
-    t_i_last = min(t_i.begin + chunk_size, seqlen) - 1
+    t_i_last = min(t_i.begin + chunk_size - 1, seqlen - 1)
     b_g_last = g[tile_b.begin, t_i_last, tile_h.begin].to(acc_dtype)
     b_g = g[tile_b.begin, t_i, tile_h.begin].to(acc_dtype)
     b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None]
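Note that the old and new expressions are mathematically identical for integers, since `min(a, b) - 1 == min(a - 1, b - 1)`, so the change should be a semantic no-op; only the rewritten form trips the hoisting bug. A quick sanity check in plain Python:

```python
# Sanity check: the pre- and post-diff expressions agree for every chunk
# start (seqlen and chunk_size as in the repro below).
seqlen, chunk_size = 4096, 256
for begin in range(0, seqlen, chunk_size):
    old = min(begin + chunk_size, seqlen) - 1        # pre-diff form
    new = min(begin + chunk_size - 1, seqlen - 1)    # post-diff form
    assert old == new
```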
$ python examples/gdn_fwd_h.py
Testing helion correctness...
[0s] Autotune random seed: 2727137739
Traceback (most recent call last):
File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/autotuner/base_search.py", line 188, in _compute_baseline
baseline_output = self.kernel.compile_config(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/torchinductor_mhoehnerbach/pj/cpjaap23wmi72e2hnjsd4ekh2kqwxu4glx5tiikiw7mr2brcnvae.py", line 121, in helion_gdn_fwd_h
_launcher(_helion_helion_gdn_fwd_h, (8 * 80 * triton.cdiv(128, _BLOCK_SIZE_0),), h, w, u, g, k, _BLOCK_SIZE_0, _RDIM_SIZE_4, _BLOCK_SIZE_3, num_warps=4, num_stages=1)
File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/runtime/__init__.py", line 86, in default_launcher
return triton_kernel.run(
^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/linear-attention/venv-new-tilelang/lib/python3.12/site-packages/triton/runtime/jit.py", line 733, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/linear-attention/venv-new-tilelang/lib/python3.12/site-packages/triton/runtime/jit.py", line 861, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/linear-attention/venv-new-tilelang/lib/python3.12/site-packages/triton/compiler/compiler.py", line 300, in compile
module = src.make_ir(target, options, codegen_fns, module_map, context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/linear-attention/venv-new-tilelang/lib/python3.12/site-packages/triton/compiler/compiler.py", line 80, in make_ir
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.errors.CompilationError: at 18:38:
pid_0 = tl.program_id(0) % num_blocks_0
pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
offset_1 = pid_0
offset_2 = pid_1
offset_0 = pid_2 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
indices_5 = tl.arange(0, _RDIM_SIZE_4).to(tl.int32)
# src[gdn_fwd_h.py:57]: b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
b_h = tl.full([64, _BLOCK_SIZE_0], 0.0, tl.float32)
# src[gdn_fwd_h.py:67]: b_g_last = g[tile_b.begin, t_i_last, tile_h.begin].to(acc_dtype)
symnode_0 = 4095 * (4095 <= 255 + offset_4) + (255 + offset_4) * (255 + offset_4 < 4095)
^
NameError('offset_4 is not defined')
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/data/users/mhoehnerbach/projects/linear-attention/helion/examples/gdn_fwd_h.py", line 210, in <module>
main()
File "/data/users/mhoehnerbach/projects/linear-attention/helion/examples/gdn_fwd_h.py", line 206, in main
test(8, 80, 4096, 256, 64, 128)
File "/data/users/mhoehnerbach/projects/linear-attention/helion/examples/gdn_fwd_h.py", line 196, in test
run_example(helion_gdn_fwd_h, ref_gdn_fwd_h, args)
File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/_testing.py", line 607, in run_example
func(*args).to(torch.float32),
^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/runtime/kernel.py", line 330, in __call__
return self.bind(args)(*args)
^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/runtime/kernel.py", line 696, in __call__
self.autotune(args, force=False)
File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/runtime/kernel.py", line 574, in autotune
config = self.settings.autotuner_fn(self, args, **kwargs).autotune(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/runtime/settings.py", line 253, in default_autotuner_fn
return cache_cls(autotuner_cls(bound_kernel, args, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/autotuner/pattern_search.py", line 45, in __init__
super().__init__(kernel, args)
File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/autotuner/base_search.py", line 707, in __init__
super().__init__(kernel, args)
File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/autotuner/base_search.py", line 136, in __init__
) = self._compute_baseline()
^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mhoehnerbach/projects/linear-attention/helion/helion/autotuner/base_search.py", line 203, in _compute_baseline
raise exc.InvalidConfig(
helion.exc.InvalidConfig: Default config failed while computing baseline.
Default config: @helion.kernel(config=helion.Config(block_sizes=[32], indexing=['pointer', 'pointer', 'pointer', 'pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['', '', '', '', ''], loop_orders=[[0, 1, 2]], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[], range_unroll_factors=[0, 0], range_warp_specializes=[]), static_shapes=True)
Enable HELION_AUTOTUNE_LOG_LEVEL=DEBUG to log generated Triton code.
To work around this error, you could set `@helion.kernel(autotune_baseline_fn=...)` to provide a custom baseline function (e.g. PyTorch eager implementation of your kernel).
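For reference, the `symnode_0` line in the first traceback is a branchless select encoding of `min`: assuming the repro's `seqlen = 4096` and `chunk_size = 256` (from `test(8, 80, 4096, 256, 64, 128)`, matching the constants 4095 and 255 in the error), it computes `min(255 + offset_4, 4095)`, i.e. `min(t_i.begin + chunk_size - 1, seqlen - 1)` with `offset_4` as the chunk offset. The `NameError` follows because this statement is emitted in the kernel prologue, before the loop that defines `offset_4`. The identity can be checked in plain Python:

```python
# Branchless min, as it appears in the generated code:
#   b * (b <= a) + a * (a < b) == min(a, b)   (bools multiply as 0/1)
def branchless_min(a: int, b: int) -> int:
    return b * (b <= a) + a * (a < b)

for offset_4 in range(0, 4096, 256):  # chunk offsets from the repro
    assert branchless_min(255 + offset_4, 4095) == min(offset_4 + 255, 4095)
```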