-
Notifications
You must be signed in to change notification settings - Fork 74
example: gated delta net fwd_h #1119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
v0i0
wants to merge
3
commits into
main
Choose a base branch
from
v0i0/gdn-fwd-h
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+403
−0
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,210 @@ | ||
| """ | ||
| Gated Delta Net Fwd H Kernel | ||
| ============================ | ||
|
|
||
| This code implements a fwd_h kernel as used in gated delta net | ||
| """ | ||
|
|
||
| # %% | ||
| # Imports | ||
| # ------- | ||
| from __future__ import annotations | ||
|
|
||
| import math | ||
| from typing import Callable | ||
|
|
||
| import torch | ||
|
|
||
| import helion | ||
| from helion._testing import DEVICE | ||
| from helion._testing import run_example | ||
| import helion.language as hl | ||
|
|
||
|
|
||
| # %% | ||
| # Helion Kernel Implementation | ||
| # ---------------------------- | ||
| @helion.kernel() | ||
| def helion_gdn_fwd_h( | ||
| k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Argument: | ||
| k: (batch, seqlen, nheads, dhead) | ||
| w: (batch, seqlen, nheads, dhead) | ||
| u: (batch, seqlen, nheads, expand_v*dhead) | ||
| g: (batch, seqlen, nheads) | ||
| chunk_size: int | ||
| Return: | ||
| h: (batch, nchunks, nheads, dhead, expand_v*dhead) | ||
| """ | ||
|
|
||
| batch, seqlen, nheads, dhead = k.shape | ||
| dhead = hl.specialize(dhead) | ||
| chunk_size = hl.specialize(chunk_size) | ||
| dstate = u.shape[-1] | ||
|
|
||
| acc_dtype = torch.float32 | ||
| dtype = k.dtype | ||
|
|
||
| nchunks = (seqlen + chunk_size - 1) // chunk_size | ||
| h = torch.empty(batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k.device) | ||
| block_v = hl.register_block_size(dstate) | ||
|
|
||
| for tile_b, tile_h, tile_v in hl.tile( | ||
| [batch, nheads, dstate], block_size=[1, 1, block_v] | ||
| ): | ||
| b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype) | ||
| for t_i in hl.tile(seqlen, block_size=chunk_size): | ||
| h[tile_b.begin, t_i.id, tile_h.begin, :, tile_v] = b_h.to(dtype) | ||
| b_w = w[tile_b.begin, t_i, tile_h.begin, :] | ||
| c_h = b_h.to(dtype) | ||
| b_v = hl.dot(b_w, c_h, out_dtype=acc_dtype) | ||
| p_v = u[tile_b.begin, t_i, tile_h.begin, tile_v].to(acc_dtype) | ||
| b_v = p_v - b_v | ||
| m_t = t_i.index < seqlen | ||
| t_i_last = min(t_i.begin + chunk_size, seqlen) - 1 | ||
| b_g_last = g[tile_b.begin, t_i_last, tile_h.begin].to(acc_dtype) | ||
| b_g = g[tile_b.begin, t_i, tile_h.begin].to(acc_dtype) | ||
| b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None] | ||
| b_g_last = torch.exp(b_g_last) | ||
| b_h *= b_g_last | ||
| b_v = b_v.to(dtype) | ||
| p_k = k[tile_b.begin, t_i, tile_h.begin, :] | ||
| b_h = hl.dot(p_k.T, b_v, acc=b_h) | ||
| return h | ||
|
|
||
|
|
||
| def helion_gdn_fwd_h_tb( | ||
| tb_obj: object, | ||
| k: torch.Tensor, | ||
| w: torch.Tensor, | ||
| u: torch.Tensor, | ||
| g: torch.Tensor, | ||
| chunk_size: int, | ||
| ) -> Callable[[], torch.Tensor]: | ||
| """ | ||
| Argument: | ||
| k: (batch, seqlen, nheads, dhead) | ||
| w: (batch, seqlen, nheads, dhead) | ||
| u: (batch, seqlen, nheads, expand_v*dhead) | ||
| g: (batch, seqlen, nheads) | ||
| chunk_size: int | ||
| Return: | ||
| h: (batch, nchunks, nheads, dhead, expand_v*dhead) | ||
| """ | ||
| return lambda: helion_gdn_fwd_h(k, w, u, g, chunk_size) | ||
|
|
||
|
|
||
| # %% | ||
| # Reference Function | ||
| # ------------- | ||
| def ref_gdn_fwd_h( | ||
| k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Argument: | ||
| k: (batch, seqlen, nheads, dhead) | ||
| w: (batch, seqlen, nheads, dhead) | ||
| u: (batch, seqlen, nheads, expand_v*dhead) | ||
| g: (batch, seqlen, nheads) | ||
| chunk_size: int | ||
| Return: | ||
| h: (batch, nchunks, nheads, dhead, expand_v*dhead) | ||
| """ | ||
|
|
||
| batch, seqlen, nheads, dhead = k.shape | ||
| expand_v = u.shape[-1] // dhead | ||
| nchunks = (seqlen + chunk_size - 1) // chunk_size | ||
|
|
||
| acc_dtype = torch.float32 | ||
| dtype = k.dtype | ||
|
|
||
| h = torch.empty( | ||
| batch, nchunks, nheads, dhead, expand_v * dhead, dtype=k.dtype, device=k.device | ||
| ) | ||
| b_h = torch.zeros( | ||
| batch, nheads, dhead, expand_v * dhead, dtype=acc_dtype, device=k.device | ||
| ) | ||
|
|
||
| k_c = k.reshape(batch, nchunks, chunk_size, nheads, dhead) | ||
| w_c = w.reshape(batch, nchunks, chunk_size, nheads, dhead) | ||
| u_c = u.reshape(batch, nchunks, chunk_size, nheads, expand_v * dhead) | ||
| g_c = g.reshape(batch, nchunks, chunk_size, nheads) | ||
| for i_t in range(nchunks): | ||
| h[:, i_t, :, :, :] = b_h.to(dtype) | ||
| b_w = w_c[:, i_t, :, :, :].to(acc_dtype) | ||
| c_h = b_h.to(dtype).to(acc_dtype) | ||
| b_v = torch.einsum("bchk,bhkv->bchv", b_w, c_h) | ||
| p_v = u_c[:, i_t, :, :, :].to(acc_dtype) | ||
| b_v = p_v - b_v | ||
| last_idx = min((i_t + 1) * chunk_size, seqlen) - 1 | ||
| m_t = (i_t * chunk_size + torch.arange(0, chunk_size, device=k.device)) < seqlen | ||
| b_g_last = g[:, last_idx, :].to(acc_dtype) | ||
| b_g = g_c[:, i_t, :, :].to(acc_dtype) # batch, chunk, nheads | ||
| b_v *= torch.where( | ||
| m_t.unsqueeze(0).unsqueeze(-1), torch.exp(b_g_last.unsqueeze(1) - b_g), 0 | ||
| ).unsqueeze(-1) | ||
| b_g_last = torch.exp(b_g_last) | ||
| b_h *= b_g_last.unsqueeze(-1).unsqueeze(-1) | ||
| b_v = b_v.to(dtype).to(acc_dtype) | ||
| p_k = k_c[:, i_t, :, :, :].to(acc_dtype) | ||
| b_h += torch.einsum("bchk,bchv->bhkv", p_k, b_v) | ||
| return h | ||
|
|
||
|
|
||
| # %% | ||
| # Testing Function | ||
| # ------------- | ||
| def test( | ||
| batch: int, | ||
| nheads: int, | ||
| seqlen: int, | ||
| chunk_size: int, | ||
| dhead: int, | ||
| dstate: int, | ||
| dtype: torch.dtype = torch.float16, | ||
| ) -> None: | ||
| k = torch.randn(batch, seqlen, nheads, dhead, dtype=torch.bfloat16, device=DEVICE) | ||
| k = torch.nn.functional.rms_norm(k, (dhead,)) | ||
| w = torch.randn( | ||
| batch, | ||
| seqlen // chunk_size, | ||
| chunk_size, | ||
| nheads, | ||
| dhead, | ||
| dtype=torch.float32, | ||
| device=DEVICE, | ||
| ) | ||
| # w = torch.nn.functional.rms_norm(w.to(torch.bfloat16), (dhead,)) | ||
| wu, ws, wv = torch.linalg.svd(w.permute(0, 1, 3, 2, 4), full_matrices=False) | ||
| w = torch.einsum("bnhik,bnhkj->bnhij", wu, wv) | ||
| w = ( | ||
| w.permute(0, 1, 3, 2, 4) | ||
| .reshape(batch, seqlen, nheads, dhead) | ||
| .to(torch.bfloat16) | ||
| ) | ||
| u = torch.randn(batch, seqlen, nheads, dstate, dtype=torch.bfloat16, device=DEVICE) | ||
| u = torch.nn.functional.rms_norm(u, (dstate,)) | ||
| g = torch.cumsum( | ||
| 0.5 | ||
| * math.log(1 / dhead) | ||
| * torch.rand(batch, seqlen, nheads, dtype=torch.float32, device=DEVICE), | ||
| dim=1, | ||
| ) | ||
| args = (k, w, u, g, chunk_size) | ||
| run_example(helion_gdn_fwd_h, ref_gdn_fwd_h, args) | ||
|
|
||
|
|
||
| # %% | ||
| # Main Function | ||
| # ----------- | ||
| def main() -> None: | ||
| """ | ||
| Main entry point that runs the attention kernel test with specific parameters. | ||
| """ | ||
| test(8, 80, 4096, 256, 64, 128) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is it specialized on 1 out of curiosity? Also why are all of the rest using
tile_b.begininstead of justtile_b? Feels a bit ugly 🤔Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this seems like a good way (to me) to get a grid for purely batched portions of the grid (setting block size 1).
the advantage of using begin then means that the resulting tensors are indexed not sliced, i.e. a[t.begin, :] is 1-d but a[t, :] is a 2-d tensor (with first dim 1). triton empirically is a lot more happy with low-dim tensors, and less typing. you basically get vmap-like syntax where you don't need to worry about batched indices once they're indexed out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm... why not just use
hl.gridthen?I understand what you're saying, but I think the resulting code reads a bit ugly/hacky. For example, I would prefer syntax perhaps like
I think the littering of
tile_b.beginis semantically confusing.Also, in terms of lower-dim tensors, I would prefer to just autotune over that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 i think you had ideas about not needing
x.begin, right?