Skip to content

Commit fae0328

Browse files
committed
example: gdn_fwd_h
1 parent 51580b4 commit fae0328

File tree

2 files changed

+249
-0
lines changed

2 files changed

+249
-0
lines changed

benchmarks/run.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,11 @@ class RunResult:
336336
"examples.mamba2_chunk_state",
337337
"helion_mamba2_chunk_state_kernel",
338338
),
339+
"gdn_fwd_h": (
340+
"tritonbench.operators.gdn_fwd_h.operator",
341+
"examples.gdn_fwd_h",
342+
"helion_gdn_fwd_h_tb",
343+
),
339344
}
340345

341346

@@ -652,6 +657,13 @@ class RunResult:
652657
"helion_mamba2_chunk_state_kernel_speedup": "helion_speedup",
653658
"helion_mamba2_chunk_state_kernel_accuracy": "helion_accuracy",
654659
},
660+
"gdn_fwd_h": {
661+
"eager": "baseline",
662+
"compile_speedup": "torch_compile_speedup",
663+
"compile_accuracy": "torch_compile_accuracy",
664+
"helion_gdn_fwd_h_speedup": "helion_speedup",
665+
"helion_gdn_fwd_h_accuracy": "helion_accuracy",
666+
},
655667
}
656668

657669

examples/gdn_fwd_h.py

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
"""
2+
Gated Delta Net Fwd H Kernel
3+
============================
4+
5+
This code implements a fwd_h kernel as used in gated delta net
6+
"""
7+
8+
# %%
9+
# Imports
10+
# -------
11+
from __future__ import annotations
12+
13+
import math
14+
from typing import Callable
15+
16+
import torch
17+
18+
import helion
19+
from helion._testing import DEVICE
20+
from helion._testing import run_example
21+
import helion.language as hl
22+
23+
24+
# %%
25+
# Helion Kernel Implementation
26+
# ----------------------------
27+
@helion.kernel()
28+
def helion_gdn_fwd_h_kernel(
29+
k_c: torch.Tensor, w_c: torch.Tensor, u_c: torch.Tensor, g_c: torch.Tensor
30+
) -> torch.Tensor:
31+
"""
32+
Argument:
33+
k_c: (batch, nchunks, chunk_size, nheads, dhead)
34+
w_c: (batch, nchunks, chunk_size, nheads, dhead)
35+
u_c: (batch, nchunks, chunk_size, nheads, expand_v*dhead)
36+
g_c: (batch, nchunks, chunk_size, nheads)
37+
Return:
38+
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
39+
"""
40+
41+
batch, nchunks, chunk_size, nheads, dhead = k_c.shape
42+
dhead = hl.specialize(dhead)
43+
chunk_size = hl.specialize(chunk_size)
44+
dstate = u_c.shape[-1]
45+
46+
acc_dtype = torch.float32
47+
dtype = k_c.dtype
48+
49+
h = torch.empty(
50+
batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k_c.device
51+
)
52+
block_v = hl.register_block_size(dstate)
53+
seqlen = chunk_size * nchunks
54+
55+
for tile_b, tile_h, tile_v in hl.tile(
56+
[batch, nheads, dstate], block_size=[1, 1, block_v]
57+
):
58+
b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
59+
for i_t in range(nchunks):
60+
h[tile_b.begin, i_t, tile_h.begin, :, tile_v] = b_h.to(dtype)
61+
b_w = w_c[tile_b.begin, i_t, :, tile_h.begin, :]
62+
c_h = b_h.to(dtype)
63+
b_v = hl.dot(b_w, c_h, out_dtype=acc_dtype)
64+
p_v = u_c[tile_b.begin, i_t, :, tile_h.begin, tile_v].to(acc_dtype)
65+
b_v = p_v - b_v
66+
m_t = (i_t * chunk_size + hl.arange(0, chunk_size)) < seqlen
67+
b_g_last = g_c[tile_b.begin, i_t, chunk_size - 1, tile_h.begin].to(
68+
acc_dtype
69+
)
70+
b_g = g_c[tile_b.begin, i_t, :, tile_h.begin].to(acc_dtype)
71+
b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None]
72+
b_g_last = torch.exp(b_g_last)
73+
b_h *= b_g_last
74+
b_v = b_v.to(dtype)
75+
p_k = k_c[tile_b.begin, i_t, :, tile_h.begin, :]
76+
b_h = hl.dot(p_k.T, b_v, acc=b_h)
77+
return h
78+
79+
80+
def helion_gdn_fwd_h(
81+
k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int
82+
) -> torch.Tensor:
83+
"""
84+
Argument:
85+
k: (batch, seqlen, nheads, dhead)
86+
w: (batch, seqlen, nheads, dhead)
87+
u: (batch, seqlen, nheads, expand_v*dhead)
88+
g: (batch, seqlen, nheads)
89+
chunk_size: int
90+
Return:
91+
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
92+
"""
93+
94+
batch, seqlen, nheads, dhead = k.shape
95+
dstate = u.shape[-1]
96+
nchunks = (seqlen + chunk_size - 1) // chunk_size
97+
98+
k_c = k.reshape(batch, nchunks, chunk_size, nheads, dhead)
99+
w_c = w.reshape(batch, nchunks, chunk_size, nheads, dhead)
100+
u_c = u.reshape(batch, nchunks, chunk_size, nheads, dstate)
101+
g_c = g.reshape(batch, nchunks, chunk_size, nheads)
102+
return helion_gdn_fwd_h_kernel(k_c, w_c, u_c, g_c)
103+
104+
105+
def helion_gdn_fwd_h_tb(
106+
tb_obj: object,
107+
k: torch.Tensor,
108+
w: torch.Tensor,
109+
u: torch.Tensor,
110+
g: torch.Tensor,
111+
chunk_size: int,
112+
) -> Callable[[], torch.Tensor]:
113+
"""
114+
Argument:
115+
k: (batch, seqlen, nheads, dhead)
116+
w: (batch, seqlen, nheads, dhead)
117+
u: (batch, seqlen, nheads, expand_v*dhead)
118+
g: (batch, seqlen, nheads)
119+
chunk_size: int
120+
Return:
121+
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
122+
"""
123+
return lambda: helion_gdn_fwd_h(k, w, u, g, chunk_size)
124+
125+
126+
# %%
127+
# Reference Function
128+
# -------------
129+
def ref_gdn_fwd_h(
130+
k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int
131+
) -> torch.Tensor:
132+
"""
133+
Argument:
134+
k: (batch, seqlen, nheads, dhead)
135+
w: (batch, seqlen, nheads, dhead)
136+
u: (batch, seqlen, nheads, expand_v*dhead)
137+
g: (batch, seqlen, nheads)
138+
chunk_size: int
139+
Return:
140+
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
141+
"""
142+
143+
batch, seqlen, nheads, dhead = k.shape
144+
expand_v = u.shape[-1] // dhead
145+
nchunks = (seqlen + chunk_size - 1) // chunk_size
146+
147+
acc_dtype = torch.float32
148+
dtype = k.dtype
149+
150+
h = torch.empty(
151+
batch, nchunks, nheads, dhead, expand_v * dhead, dtype=k.dtype, device=k.device
152+
)
153+
b_h = torch.zeros(
154+
batch, nheads, dhead, expand_v * dhead, dtype=acc_dtype, device=k.device
155+
)
156+
157+
k_c = k.reshape(batch, nchunks, chunk_size, nheads, dhead)
158+
w_c = w.reshape(batch, nchunks, chunk_size, nheads, dhead)
159+
u_c = u.reshape(batch, nchunks, chunk_size, nheads, expand_v * dhead)
160+
g_c = g.reshape(batch, nchunks, chunk_size, nheads)
161+
for i_t in range(nchunks):
162+
h[:, i_t, :, :, :] = b_h.to(dtype)
163+
b_w = w_c[:, i_t, :, :, :].to(acc_dtype)
164+
c_h = b_h.to(dtype).to(acc_dtype)
165+
b_v = torch.einsum("bchk,bhkv->bchv", b_w, c_h)
166+
p_v = u_c[:, i_t, :, :, :].to(acc_dtype)
167+
b_v = p_v - b_v
168+
last_idx = min((i_t + 1) * chunk_size, seqlen) - 1
169+
m_t = (i_t * chunk_size + torch.arange(0, chunk_size, device=k.device)) < seqlen
170+
b_g_last = g[:, last_idx, :].to(acc_dtype)
171+
b_g = g_c[:, i_t, :, :].to(acc_dtype) # batch, chunk, nheads
172+
b_v *= torch.where(
173+
m_t.unsqueeze(0).unsqueeze(-1), torch.exp(b_g_last.unsqueeze(1) - b_g), 0
174+
).unsqueeze(-1)
175+
b_g_last = torch.exp(b_g_last)
176+
b_h *= b_g_last.unsqueeze(-1).unsqueeze(-1)
177+
b_v = b_v.to(dtype).to(acc_dtype)
178+
p_k = k_c[:, i_t, :, :, :].to(acc_dtype)
179+
b_h += torch.einsum("bchk,bchv->bhkv", p_k, b_v)
180+
return h
181+
182+
183+
# %%
184+
# Testing Function
185+
# -------------
186+
def test(
187+
batch: int,
188+
nheads: int,
189+
seqlen: int,
190+
chunk_size: int,
191+
dhead: int,
192+
dstate: int,
193+
dtype: torch.dtype = torch.float16,
194+
) -> None:
195+
k = torch.randn(batch, seqlen, nheads, dhead, dtype=torch.bfloat16, device=DEVICE)
196+
k = torch.nn.functional.rms_norm(k, (dhead,))
197+
w = torch.randn(
198+
batch,
199+
seqlen // chunk_size,
200+
chunk_size,
201+
nheads,
202+
dhead,
203+
dtype=torch.float32,
204+
device=DEVICE,
205+
)
206+
# w = torch.nn.functional.rms_norm(w.to(torch.bfloat16), (dhead,))
207+
wu, ws, wv = torch.linalg.svd(w.permute(0, 1, 3, 2, 4), full_matrices=False)
208+
w = torch.einsum("bnhik,bnhkj->bnhij", wu, wv)
209+
w = (
210+
w.permute(0, 1, 3, 2, 4)
211+
.reshape(batch, seqlen, nheads, dhead)
212+
.to(torch.bfloat16)
213+
)
214+
u = torch.randn(batch, seqlen, nheads, dstate, dtype=torch.bfloat16, device=DEVICE)
215+
u = torch.nn.functional.rms_norm(u, (dstate,))
216+
g = torch.cumsum(
217+
0.5
218+
* math.log(1 / dhead)
219+
* torch.rand(batch, seqlen, nheads, dtype=torch.float32, device=DEVICE),
220+
dim=1,
221+
)
222+
args = (k, w, u, g, chunk_size)
223+
run_example(helion_gdn_fwd_h, ref_gdn_fwd_h, args)
224+
225+
226+
# %%
227+
# Main Function
228+
# -----------
229+
def main() -> None:
230+
"""
231+
Main entry point that runs the attention kernel test with specific parameters.
232+
"""
233+
test(8, 80, 4096, 256, 64, 128)
234+
235+
236+
if __name__ == "__main__":
237+
main()

0 commit comments

Comments
 (0)