12 changes: 12 additions & 0 deletions benchmarks/run.py
@@ -335,6 +335,11 @@ class RunResult:
"examples.mamba2_chunk_state",
"helion_mamba2_chunk_state_kernel",
),
"gdn_fwd_h": (
"tritonbench.operators.gdn_fwd_h.operator",
"examples.gdn_fwd_h",
"helion_gdn_fwd_h_tb",
),
}


@@ -651,6 +656,13 @@ class RunResult:
"helion_mamba2_chunk_state_kernel_speedup": "helion_speedup",
"helion_mamba2_chunk_state_kernel_accuracy": "helion_accuracy",
},
"gdn_fwd_h": {
"eager": "baseline",
"compile_speedup": "torch_compile_speedup",
"compile_accuracy": "torch_compile_accuracy",
"helion_gdn_fwd_h_speedup": "helion_speedup",
"helion_gdn_fwd_h_accuracy": "helion_accuracy",
},
}


210 changes: 210 additions & 0 deletions examples/gdn_fwd_h.py
@@ -0,0 +1,210 @@
"""
Gated Delta Net Fwd H Kernel
============================

This code implements the fwd_h kernel used in Gated Delta Net: it computes the
per-chunk recurrent state h via a gated delta-rule update (see ref_gdn_fwd_h
below for a pure PyTorch reference).
"""

# %%
# Imports
# -------
from __future__ import annotations

import math
from typing import Callable

import torch

import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl


# %%
# Helion Kernel Implementation
# ----------------------------
@helion.kernel()
def helion_gdn_fwd_h(
k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int
) -> torch.Tensor:
"""
Argument:
k: (batch, seqlen, nheads, dhead)
w: (batch, seqlen, nheads, dhead)
u: (batch, seqlen, nheads, expand_v*dhead)
g: (batch, seqlen, nheads)
chunk_size: int
Return:
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
"""

batch, seqlen, nheads, dhead = k.shape
dhead = hl.specialize(dhead)
chunk_size = hl.specialize(chunk_size)
dstate = u.shape[-1]

acc_dtype = torch.float32
dtype = k.dtype

nchunks = (seqlen + chunk_size - 1) // chunk_size
h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k.device)
block_v = hl.register_block_size(dstate)

for tile_b, tile_h, tile_v in hl.tile(
[batch, nheads, dstate], block_size=[1, 1, block_v]
why is it specialized on 1 out of curiosity? Also why are all of the rest using tile_b.begin instead of just tile_b? Feels a bit ugly 🤔

@v0i0 (Contributor Author), Nov 21, 2025

This seems like a good way (to me) to get a grid for the purely batched portions of the computation (setting block size to 1).
The advantage of using begin is that the resulting tensors are indexed rather than sliced, i.e. a[t.begin, :] is 1-D while a[t, :] is a 2-D tensor (with first dim 1). Triton is, empirically, a lot happier with low-dimensional tensors, and it is less typing. You basically get vmap-like syntax where you don't need to worry about the batched indices once they're indexed out.
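A minimal sketch of the indexed-vs-sliced distinction described above, using plain PyTorch indexing as an analogy (the tensor a and its shapes are invented for illustration; inside a Helion kernel the same difference shows up between a[tile.begin, :] and a[tile, :]):

import torch

a = torch.randn(4, 8)
print(a[0, :].shape)    # torch.Size([8]): an integer index drops the leading dim ("indexed")
print(a[0:1, :].shape)  # torch.Size([1, 8]): a slice keeps a size-1 leading dim ("sliced")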


Hmm... why not just use hl.grid then?

I understand what you're saying, but I think the resulting code reads a bit ugly/hacky. For example, I would prefer syntax like

for idx_b, idx_h, tile_v in hl.tile([batch, nheads, dstate], block_size=[0, 0, block_v])

I think the littering of tile_b.begin is semantically confusing.

Also, in terms of lower-dim tensors, I would prefer to just autotune over that.

@v0i0 (Contributor Author)

@yf225 I think you had ideas about not needing x.begin, right?

):
b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
for t_i in hl.tile(seqlen, block_size=chunk_size):
h[tile_b.begin, t_i.id, tile_h.begin, :, tile_v] = b_h.to(dtype)
b_w = w[tile_b.begin, t_i, tile_h.begin, :]
c_h = b_h.to(dtype)
b_v = hl.dot(b_w, c_h, out_dtype=acc_dtype)
p_v = u[tile_b.begin, t_i, tile_h.begin, tile_v].to(acc_dtype)
b_v = p_v - b_v
m_t = t_i.index < seqlen
t_i_last = min(t_i.begin + chunk_size, seqlen) - 1
b_g_last = g[tile_b.begin, t_i_last, tile_h.begin].to(acc_dtype)
b_g = g[tile_b.begin, t_i, tile_h.begin].to(acc_dtype)
b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None]
b_g_last = torch.exp(b_g_last)
b_h *= b_g_last
b_v = b_v.to(dtype)
p_k = k[tile_b.begin, t_i, tile_h.begin, :]
b_h = hl.dot(p_k.T, b_v, acc=b_h)
return h


def helion_gdn_fwd_h_tb(
tb_obj: object,
k: torch.Tensor,
w: torch.Tensor,
u: torch.Tensor,
g: torch.Tensor,
chunk_size: int,
) -> Callable[[], torch.Tensor]:
"""
Argument:
k: (batch, seqlen, nheads, dhead)
w: (batch, seqlen, nheads, dhead)
u: (batch, seqlen, nheads, expand_v*dhead)
g: (batch, seqlen, nheads)
chunk_size: int
Return:
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
"""
return lambda: helion_gdn_fwd_h(k, w, u, g, chunk_size)


# %%
# Reference Function
# ------------------
def ref_gdn_fwd_h(
k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int
) -> torch.Tensor:
"""
Argument:
k: (batch, seqlen, nheads, dhead)
w: (batch, seqlen, nheads, dhead)
u: (batch, seqlen, nheads, expand_v*dhead)
g: (batch, seqlen, nheads)
chunk_size: int
Return:
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
"""

batch, seqlen, nheads, dhead = k.shape
expand_v = u.shape[-1] // dhead
nchunks = (seqlen + chunk_size - 1) // chunk_size

acc_dtype = torch.float32
dtype = k.dtype

h = torch.empty(
batch, nchunks, nheads, dhead, expand_v * dhead, dtype=k.dtype, device=k.device
)
b_h = torch.zeros(
batch, nheads, dhead, expand_v * dhead, dtype=acc_dtype, device=k.device
)

k_c = k.reshape(batch, nchunks, chunk_size, nheads, dhead)
w_c = w.reshape(batch, nchunks, chunk_size, nheads, dhead)
u_c = u.reshape(batch, nchunks, chunk_size, nheads, expand_v * dhead)
g_c = g.reshape(batch, nchunks, chunk_size, nheads)
for i_t in range(nchunks):
h[:, i_t, :, :, :] = b_h.to(dtype)
b_w = w_c[:, i_t, :, :, :].to(acc_dtype)
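        # The .to(dtype).to(acc_dtype) round-trips below intentionally mirror the
        # kernel's low-precision casts so that the reference matches its numerics.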
c_h = b_h.to(dtype).to(acc_dtype)
b_v = torch.einsum("bchk,bhkv->bchv", b_w, c_h)
p_v = u_c[:, i_t, :, :, :].to(acc_dtype)
b_v = p_v - b_v
last_idx = min((i_t + 1) * chunk_size, seqlen) - 1
m_t = (i_t * chunk_size + torch.arange(0, chunk_size, device=k.device)) < seqlen
b_g_last = g[:, last_idx, :].to(acc_dtype)
b_g = g_c[:, i_t, :, :].to(acc_dtype) # batch, chunk, nheads
b_v *= torch.where(
m_t.unsqueeze(0).unsqueeze(-1), torch.exp(b_g_last.unsqueeze(1) - b_g), 0
).unsqueeze(-1)
b_g_last = torch.exp(b_g_last)
b_h *= b_g_last.unsqueeze(-1).unsqueeze(-1)
b_v = b_v.to(dtype).to(acc_dtype)
p_k = k_c[:, i_t, :, :, :].to(acc_dtype)
b_h += torch.einsum("bchk,bchv->bhkv", p_k, b_v)
return h


# %%
# Testing Function
# ----------------
def test(
batch: int,
nheads: int,
seqlen: int,
chunk_size: int,
dhead: int,
dstate: int,
dtype: torch.dtype = torch.float16,
) -> None:
k = torch.randn(batch, seqlen, nheads, dhead, dtype=torch.bfloat16, device=DEVICE)
k = torch.nn.functional.rms_norm(k, (dhead,))
w = torch.randn(
batch,
seqlen // chunk_size,
chunk_size,
nheads,
dhead,
dtype=torch.float32,
device=DEVICE,
)
# w = torch.nn.functional.rms_norm(w.to(torch.bfloat16), (dhead,))
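    # Replace each (chunk, head) block of w with its nearest column-orthonormal matrix,
    # i.e. drop the singular values and keep U @ Vh from the SVD.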
wu, ws, wv = torch.linalg.svd(w.permute(0, 1, 3, 2, 4), full_matrices=False)
w = torch.einsum("bnhik,bnhkj->bnhij", wu, wv)
w = (
w.permute(0, 1, 3, 2, 4)
.reshape(batch, seqlen, nheads, dhead)
.to(torch.bfloat16)
)
u = torch.randn(batch, seqlen, nheads, dstate, dtype=torch.bfloat16, device=DEVICE)
u = torch.nn.functional.rms_norm(u, (dstate,))
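    # g is a cumulative sum of non-positive log-gates (log(1/dhead) < 0), so it decreases
    # along the sequence and exp(g_last - g) acts as a decay factor in (0, 1].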
g = torch.cumsum(
0.5
* math.log(1 / dhead)
* torch.rand(batch, seqlen, nheads, dtype=torch.float32, device=DEVICE),
dim=1,
)
args = (k, w, u, g, chunk_size)
run_example(helion_gdn_fwd_h, ref_gdn_fwd_h, args)


# %%
# Main Function
# -------------
def main() -> None:
"""
Main entry point that runs the gdn_fwd_h kernel test with specific parameters.
"""
test(8, 80, 4096, 256, 64, 128)


if __name__ == "__main__":
main()
124 changes: 124 additions & 0 deletions test/test_examples.expected
@@ -1708,6 +1708,130 @@ def fused_linear_jsd_kernel(beta: float, ignore_index: int, temperature: float,
# src[fused_linear_jsd.py:N]: return (loss / student_logits.shape[0]).sum()
return (loss / student_logits.shape[0]).sum()

--- assertExpectedJournal(TestExamples.test_gdn_fwd_h)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_helion_gdn_fwd_h(h, w, u, g, k, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_4: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
# src[gdn_fwd_h.py:N]: for tile_b, tile_h, tile_v in hl.tile(
# src[gdn_fwd_h.py:N]: [batch, nheads, dstate], block_size=[1, 1, block_v]
# src[gdn_fwd_h.py:N]: ):
num_blocks_0 = 8
num_blocks_1 = 80
pid_0 = tl.program_id(0) % num_blocks_0
pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
offset_1 = pid_0
offset_2 = pid_1
offset_0 = pid_2 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
indices_5 = tl.arange(0, _RDIM_SIZE_4).to(tl.int32)
# src[gdn_fwd_h.py:N]: b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
b_h = tl.full([64, _BLOCK_SIZE_0], 0.0, tl.float32)
# src[gdn_fwd_h.py:N]: for t_i in hl.tile(seqlen, block_size=chunk_size):
# src[gdn_fwd_h.py:N]: h[tile_b.begin, t_i.id, tile_h.begin, :, tile_v] = b_h.to(dtype)
# src[gdn_fwd_h.py:N]: b_w = w[tile_b.begin, t_i, tile_h.begin, :]
# src[gdn_fwd_h.py:N-N]: ...
for offset_4 in tl.range(0, 4096, _BLOCK_SIZE_3):
indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
b_h_copy = b_h
b_h_copy_0 = b_h_copy
# src[gdn_fwd_h.py:N]: h[tile_b.begin, t_i.id, tile_h.begin, :, tile_v] = b_h.to(dtype)
v_0 = tl.cast(b_h_copy_0, tl.bfloat16)
tile_id = offset_4 // _BLOCK_SIZE_3
tl.store(h + (offset_1 * 10485760 + tile_id * 655360 + offset_2 * 8192 + indices_5[:, None] * 128 + indices_0[None, :] * 1), v_0, None)
# src[gdn_fwd_h.py:N]: b_w = w[tile_b.begin, t_i, tile_h.begin, :]
b_w = tl.load(w + (offset_1 * 20971520 + indices_4[:, None] * 5120 + offset_2 * 64 + indices_5[None, :] * 1), None)
# src[gdn_fwd_h.py:N]: c_h = b_h.to(dtype)
v_1 = tl.cast(b_h_copy_0, tl.bfloat16)
# src[gdn_fwd_h.py:N]: b_v = hl.dot(b_w, c_h, out_dtype=acc_dtype)
b_v = tl.dot(tl.cast(b_w, tl.bfloat16), tl.cast(v_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
# src[gdn_fwd_h.py:N]: p_v = u[tile_b.begin, t_i, tile_h.begin, tile_v].to(acc_dtype)
load_1 = tl.load(u + (offset_1 * 41943040 + indices_4[:, None] * 10240 + offset_2 * 128 + indices_0[None, :] * 1), None)
v_2 = tl.cast(load_1, tl.float32)
# src[gdn_fwd_h.py:N]: b_v = p_v - b_v
v_3 = v_2 - b_v
# src[gdn_fwd_h.py:N]: m_t = t_i.index < seqlen
v_4 = tl.full([], 4096, tl.int32)
v_5 = indices_4 < v_4
# src[gdn_fwd_h.py:N]: t_i_last = min(t_i.begin + chunk_size, seqlen) - 1
sub_1 = -1 + (4096 * (4096 <= 256 + offset_4) + (256 + offset_4) * (256 + offset_4 < 4096))
# src[gdn_fwd_h.py:N]: b_g_last = g[tile_b.begin, t_i_last, tile_h.begin].to(acc_dtype)
b_g_last = tl.load(g + (offset_1 * 327680 + sub_1 * 80 + offset_2 * 1), None)
# src[gdn_fwd_h.py:N]: b_g = g[tile_b.begin, t_i, tile_h.begin].to(acc_dtype)
b_g = tl.load(g + (offset_1 * 327680 + indices_4 * 80 + offset_2 * 1), None)
# src[gdn_fwd_h.py:N]: b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None]
v_6 = b_g_last[None]
v_7 = v_6 - b_g
v_8 = libdevice.exp(v_7)
v_9 = 0.0
v_10 = v_9[None]
v_11 = tl.where(v_5, v_8, v_10)
subscript = v_11[:, None]
v_12 = v_3 * subscript
# src[gdn_fwd_h.py:N]: b_g_last = torch.exp(b_g_last)
v_13 = libdevice.exp(b_g_last)
# src[gdn_fwd_h.py:N]: b_h *= b_g_last
v_14 = v_13[None, None]
v_15 = b_h_copy_0 * v_14
# src[gdn_fwd_h.py:N]: b_v = b_v.to(dtype)
v_16 = tl.cast(v_12, tl.bfloat16)
# src[gdn_fwd_h.py:N]: p_k = k[tile_b.begin, t_i, tile_h.begin, :]
p_k = tl.load(k + (offset_1 * 20971520 + indices_4[:, None] * 5120 + offset_2 * 64 + indices_5[None, :] * 1), None)
# src[gdn_fwd_h.py:N]: b_h = hl.dot(p_k.T, b_v, acc=b_h)
permute = tl.permute(p_k, [1, 0])
b_h = tl.dot(tl.cast(permute, tl.bfloat16), tl.cast(v_16, tl.bfloat16), acc=v_15, input_precision='tf32', out_dtype=tl.float32)

def helion_gdn_fwd_h(k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int, *, _launcher=_default_launcher):
"""
Argument:
k: (batch, seqlen, nheads, dhead)
w: (batch, seqlen, nheads, dhead)
u: (batch, seqlen, nheads, expand_v*dhead)
g: (batch, seqlen, nheads)
chunk_size: int
Return:
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
"""
# src[gdn_fwd_h.py:N]: batch, seqlen, nheads, dhead = k.shape
batch, seqlen, nheads, dhead = k.shape
# src[gdn_fwd_h.py:N]: dhead = hl.specialize(dhead)
dhead = 64
# src[gdn_fwd_h.py:N]: chunk_size = hl.specialize(chunk_size)
chunk_size = 256
# src[gdn_fwd_h.py:N]: dstate = u.shape[-1]
dstate = u.shape[-1]
# src[gdn_fwd_h.py:N]: acc_dtype = torch.float32
acc_dtype = torch.float32
# src[gdn_fwd_h.py:N]: dtype = k.dtype
dtype = k.dtype
# src[gdn_fwd_h.py:N]: nchunks = (seqlen + chunk_size - 1) // chunk_size
nchunks = (seqlen + chunk_size - 1) // chunk_size
# src[gdn_fwd_h.py:N]: h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k.device)
h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k.device)
# src[gdn_fwd_h.py:N]: for tile_b, tile_h, tile_v in hl.tile(
# src[gdn_fwd_h.py:N]: [batch, nheads, dstate], block_size=[1, 1, block_v]
# src[gdn_fwd_h.py:N]: ):
_BLOCK_SIZE_0 = 32
_RDIM_SIZE_4 = 64
# src[gdn_fwd_h.py:N]: for t_i in hl.tile(seqlen, block_size=chunk_size):
# src[gdn_fwd_h.py:N]: h[tile_b.begin, t_i.id, tile_h.begin, :, tile_v] = b_h.to(dtype)
# src[gdn_fwd_h.py:N]: b_w = w[tile_b.begin, t_i, tile_h.begin, :]
# src[gdn_fwd_h.py:N-N]: ...
_BLOCK_SIZE_3 = 256
# src[gdn_fwd_h.py:N]: for tile_b, tile_h, tile_v in hl.tile(
# src[gdn_fwd_h.py:N]: [batch, nheads, dstate], block_size=[1, 1, block_v]
# src[gdn_fwd_h.py:N]: ):
# src[gdn_fwd_h.py:N-N]: ...
_launcher(_helion_helion_gdn_fwd_h, (8 * 80 * triton.cdiv(128, _BLOCK_SIZE_0),), h, w, u, g, k, _BLOCK_SIZE_0, _RDIM_SIZE_4, _BLOCK_SIZE_3, num_warps=4, num_stages=1)
# src[gdn_fwd_h.py:N]: return h
return h

--- assertExpectedJournal(TestExamples.test_geglu)
from __future__ import annotations
