example: gated delta net fwd_h #1119
base: main
Conversation
Hi! I have tested this kernel with some large-seqlen configs (32k seqlen with different numbers of heads, typically (1, 32, 32768, 128, 1)). The autotune process takes a long time and finally gives a best config, but the accuracy is zero.
Hey @tzj-fxz, thank you for checking it out! Initialization and error checking for this kernel seem tricky; did you see similar issues for the other implementations in tritonbench, or with other shapes? I wonder if a more robust way to generate the inputs would be to capture them out of FLA or something else.
Yes. I have tested several seqlens (4k, 8k, 16k, 32k) with the other configs frozen. There is always the same error message showing that the accuracy check failed. BTW, the 1k and 2k cases run successfully. :)
I see. I suspect this is just an issue with our chosen reference implementation; here is a run with lots of kernels. Note how FLA and helion match exactly in their accuracy field.
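For illustration, a minimal sketch of the "capture the inputs" idea mentioned above: dump the tensors a trusted implementation was called with and replay them against the helion kernel. The file path, helper names, and dictionary keys are assumptions for this sketch, not part of the PR.

```python
import torch


def capture_inputs(path, **tensors):
    # Dump the exact tensors a trusted implementation was called with,
    # e.g. from inside an FLA run, so the same inputs can be replayed here.
    torch.save(tensors, path)


def replay_inputs(path, kernel):
    # Reload the captured inputs and run them through the kernel under test.
    inputs = torch.load(path)
    return kernel(
        inputs["k"], inputs["w"], inputs["u"], inputs["g"], inputs["chunk_size"]
    )
```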
examples/gdn_fwd_h.py (outdated)
def helion_gdn_fwd_h(
    k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int
btw, why not do the reshape within the helion kernel itself?
If you do the reshape in the hl.kernel, helion will check that it is correct under dynamic shapes (or with static shapes), but I've had trouble making the check work (might be a bug, or I need to assert something).
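For illustration, a rough sketch of what moving the reshape into the kernel could look like. The body is elided and the shape names/layout are assumptions, not the PR's actual code.

```python
import torch
import helion


@helion.kernel
def helion_gdn_fwd_h(
    k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int
):
    # Assumed layout: k is [batch, nheads, seqlen, headdim]; adjust to whatever
    # the example actually uses.
    batch, nheads, seqlen, headdim = k.shape
    num_chunks = seqlen // chunk_size  # assumes divisibility for this sketch
    # Reshaping here, rather than in a host-side wrapper, lets helion check
    # that the view is valid under dynamic (or static) shapes.
    k = k.reshape(batch, nheads, num_chunks, chunk_size, headdim)
    ...
```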
I'll look into this issue
opened a PR to fix it: #1146
what if the seqlen is not divisible by chunk size?
Might give that a shot now that the above fix is in :-) I feel like it should be fairly natural to express with helion.
@sustcsonglin that worked, and it's fairly neat even!
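For illustration, roughly the shape of that approach. This is a sketch: it assumes hl.tile masks the ragged final chunk, that block_size can be bound to the chunk_size argument, and an output layout of [batch, nheads, num_chunks, headdim, dstate]; it is not the PR's actual code.

```python
# Ceil-divide so the trailing partial chunk still gets a state slot.
num_chunks = (seqlen + chunk_size - 1) // chunk_size
# Assumed output layout [batch, nheads, num_chunks, headdim, dstate].
h = torch.zeros(
    batch, nheads, num_chunks, headdim, dstate, dtype=torch.float32, device=k.device
)
for tile_t in hl.tile(seqlen, block_size=chunk_size):
    # The final tile covers fewer than chunk_size positions when
    # seqlen % chunk_size != 0; out-of-range loads/stores are masked.
    ...
```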
looks great! and really neat
yf225 left a comment:
It looks great, thanks @v0i0! Maybe also add a unit test to test_examples.py?
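For illustration, a rough sketch of such a test. The repo's test_examples.py likely has its own harness and expected-output conventions, so the module path, shapes, layout, and smoke-style check below are all assumptions.

```python
import unittest

import torch

from examples.gdn_fwd_h import helion_gdn_fwd_h  # assumed module path


class TestGdnFwdH(unittest.TestCase):
    def test_small_config_runs(self):
        torch.manual_seed(0)
        # Shapes/layout are a guess; match whatever examples/gdn_fwd_h.py uses.
        batch, nheads, seqlen, headdim, dstate, chunk_size = 1, 4, 1024, 64, 64, 64
        k = torch.randn(batch, nheads, seqlen, headdim, device="cuda")
        w = torch.randn(batch, nheads, seqlen, headdim, device="cuda")
        u = torch.randn(batch, nheads, seqlen, dstate, device="cuda")
        g = torch.randn(batch, nheads, seqlen, device="cuda")
        out = helion_gdn_fwd_h(k, w, u, g, chunk_size)
        # Smoke check only; a real test would compare against a reference.
        self.assertTrue(torch.isfinite(out).all())


if __name__ == "__main__":
    unittest.main()
```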
block_v = hl.register_block_size(dstate)

for tile_b, tile_h, tile_v in hl.tile(
    [batch, nheads, dstate], block_size=[1, 1, block_v]
Out of curiosity, why is it specialized on 1? Also, why are all of the rest using tile_b.begin instead of just tile_b? Feels a bit ugly 🤔
Setting block size 1 seems like a good way (to me) to get a grid over the purely batched portions of the iteration space.
The advantage of using begin is that the resulting tensors are indexed rather than sliced, i.e. a[t.begin, :] is 1-d but a[t, :] is a 2-d tensor (with first dim 1). Empirically, triton is a lot happier with low-dim tensors, and it's less typing: you basically get vmap-like syntax where you don't need to worry about batched indices once they're indexed out.
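For illustration, a tiny sketch of that indexing difference, assuming a tensor a of shape [batch, nheads, dstate] inside the kernel body (names mirror the loop above; not the PR's exact code):

```python
for tile_b, tile_h, tile_v in hl.tile(
    [batch, nheads, dstate], block_size=[1, 1, block_v]
):
    sliced = a[tile_b, tile_h, tile_v]               # 3-d, shape [1, 1, block_v]
    indexed = a[tile_b.begin, tile_h.begin, tile_v]  # 1-d, shape [block_v]
```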
hmm... why not just use hl.grid then?
I understand what you're saying, but I think the resulting code reads a bit ugly/hacky. For example, I would prefer syntax perhaps like
for idx_b, idx_h, tile_v in hl.tile([batch, nheads, dstate], block_size=[0, 0, block_v])
I think the littering of tile_b.begin is semantically confusing.
Also, in terms of lower-dim tensors, I would prefer to just autotune over that.
@yf225 I think you had ideas about not needing x.begin, right?