
Conversation

@v0i0 (Contributor) commented Nov 12, 2025

  (Batch, Heads, SeqLen, ChunkSize, Dhead, ExpandV)    eager-gbps    compile-gbps    fla-gbps    tilelang-gbps    helion_helion_gdn_fwd_h_tb-gbps
---------------------------------------------------  ------------  --------------  ----------  ---------------  ---------------------------------
                           (1, 6, 1024, 64, 256, 2)       3.92615         7.40171     373.228          261.995                            439.151
                                            average       3.92615         7.40171     373.228          261.995                            439.151

@v0i0 v0i0 requested a review from yf225 November 12, 2025 21:48
@meta-cla bot added the CLA Signed label Nov 12, 2025
@tzj-fxz commented Nov 13, 2025

Hi! I have tested this kernel with some large-seqlen configs (32k seqlen with different head counts, typically (1, 32, 32768, 128, 1)). The autotune process takes a long time and eventually settles on a best config, but the accuracy is zero.

Module          FLOP    % Total
-----------  -------  ---------
Global       68.719B    100.00%
 - aten.bmm  68.719B    100.00%
[tritonbench] Output result csv to /tmp/tmp9y222nql.csv
(Batch, Heads, SeqLen, ChunkSize, Dhead, ExpandV) = (1, 32, 32768, 64, 128, 1)

Metric     eager                 compile                helion_helion_gdn_fwd_h_tb
---------  --------------------  ---------------------  ---------------------------
speedup                          0.814987395509034      142.42771271025643
accuracy                         0.0                    0.0
tflops     0.6079966534761014    0.49550960909469655    86.59557269009149
gbps       7.7558166855777095    6.320892840624485      1104.6432307268751
latency    113.02607727050781    138.6844482421875      0.7935680150985718

(The CSV's average row repeats these values, since only a single config was run.)
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
TritonBench accuracy check failed with Helion kernel config: @helion.kernel(config=helion.Config(block_sizes=[32], indexing=['tensor_descriptor', 'pointer', 'pointer', 'tensor_descriptor', 'pointer', 'tensor_descriptor'], l2_groupings=[4], load_eviction_policies=['last', 'last', 'first', '', 'first'], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=4, pid_type='persistent_blocked', range_flattens=[False, None], range_multi_buffers=[False, False], range_num_stages=[1, 2], range_unroll_factors=[1, 1], range_warp_specializes=[]), static_shapes=True)
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

@v0i0 (Contributor, Author) commented Nov 13, 2025

Hey @tzj-fxz, thank you for checking it out! Initialization and error checking for this kernel seem tricky. Did you see similar issues with the other implementations in tritonbench, or with other shapes? I wonder if a more robust way to generate the inputs would be to capture them from FLA or something similar, as sketched below.
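One hedged sketch of that capture idea (fla_fwd, the file path, and the tolerances are placeholders, not tritonbench's actual hooks):

import torch

# Snapshot the exact inputs/outputs of a trusted implementation once, then
# replay them against the Helion kernel. `fla_fwd` stands in for whichever
# FLA entry point tritonbench invokes; it is a placeholder, not a real import.
def capture_reference(fla_fwd, inputs, path="gdn_ref.pt"):
    with torch.no_grad():
        ref_out = fla_fwd(*inputs)
    torch.save({"inputs": inputs, "out": ref_out}, path)

def check_against_reference(kernel_fn, path="gdn_ref.pt"):
    blob = torch.load(path)
    torch.testing.assert_close(
        kernel_fn(*blob["inputs"]), blob["out"], rtol=1e-2, atol=1e-2
    )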

@tzj-fxz commented Nov 14, 2025

> Hey @tzj-fxz, thank you for checking it out! Initialization and error checking for this kernel seem tricky. Did you see similar issues with the other implementations in tritonbench, or with other shapes?

Yes. I have tested several sequence lengths (4k, 8k, 16k, 32k) with the other config fields frozen, and I always get the same accuracy-check failure. BTW, the 1k and 2k cases run successfully. :)

@v0i0 (Contributor, Author) commented Nov 14, 2025

I see. I suspect this is just an issue with our chosen reference implementation. Here is a run with several kernels; note how FLA and Helion match exactly in their accuracy fields.

  (Batch, Heads, SeqLen, ChunkSize, Dhead, ExpandV)    eager-gbps    compile-accuracy    compile-gbps    fla-accuracy    fla-gbps    tilelang-accuracy    tilelang-gbps    helion_helion_gdn_fwd_h_tb-accuracy    helion_helion_gdn_fwd_h_tb-gbps
---------------------------------------------------  ------------  ------------------  --------------  --------------  ----------  -------------------  ---------------  -------------------------------------  ---------------------------------
                           (1, 6, 1024, 64, 256, 2)       3.97343            1                7.32136             1       371.538                    0          261.161                                    1                              435.653
                           (1, 6, 2048, 64, 256, 2)       3.68888            0               10.1702              1       407.355                    0          281.227                                    1                              489.109
                           (1, 6, 4096, 64, 256, 2)       3.72543            0                8.47115             0       429.852                    0          293.247                                    0                              527.357
                          (16, 6, 1024, 64, 256, 2)      29.2867             0               40.8267              1       590.805                    0          628.67                                     1                              710.052
                          (16, 6, 2048, 64, 256, 2)      29.3871             0               17.4041              0       574.055                    0          648.2                                      0                              515.333
                          (16, 6, 4096, 64, 256, 2)      29.3519             0                9.8477              0       558.972                    0          649.403                                    0                              518.932
                                            average      16.5689             0.166667        15.6735              0.5     488.763                    0          460.318                                    0.5                            532.739



def helion_gdn_fwd_h(
    k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int

btw, why not do the reshape within the helion kernel itself?

Contributor Author

If you do the reshape in the hl.kernel, Helion will check that it is correct under dynamic shapes (or with static shapes), but I've had trouble making the check work (might be a bug, or I need to assert something).
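For context, a minimal sketch of the host-side alternative being discussed, with an assumed (batch, heads, seqlen, dhead) layout; the actual kernel's layout may differ:

import torch

def reshape_into_chunks(k: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # Hypothetical layout: doing the reshape on the host side means Helion
    # never has to verify it under dynamic or static shapes.
    batch, heads, seqlen, dhead = k.shape
    assert seqlen % chunk_size == 0  # see the divisibility question below
    return k.reshape(batch, heads, seqlen // chunk_size, chunk_size, dhead)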

Contributor

I'll look into this issue

Contributor

opened a PR to fix it: #1146


what if the seqlen is not divisible by chunk size?

Contributor Author

Might give that a shot now that the above fix is in :-) I feel like it should be fairly natural to express with Helion.
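For reference, a hedged sketch of a host-side padding fallback (the (batch, heads, seqlen, dhead) layout is an assumption; the comments below suggest Helion can instead handle the ragged last chunk directly):

import torch
import torch.nn.functional as F

def pad_to_chunk_multiple(x: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # Zero-pad the sequence dimension (dim 2 in the assumed layout) up to the
    # next multiple of chunk_size so in-kernel reshapes stay exact.
    pad = (-x.shape[2]) % chunk_size
    return F.pad(x, (0, 0, 0, pad)) if pad else x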

Contributor Author

@sustcsonglin that worked, and it's fairly neat even!


Looks great, and really neat!

@yf225 (Contributor) left a comment

It looks great, thanks @v0i0! Maybe also add a unit test to test_examples.py?
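A rough sketch of what such a test might look like (the input shapes and reference_gdn_fwd_h are assumptions, not the repo's actual test_examples.py conventions):

import torch
from torch.testing import assert_close

def make_inputs(batch=1, heads=6, seqlen=1024, dhead=256, expand_v=2):
    # Hypothetical input layout for the kernel; adjust to the real signature.
    dev, dt = "cuda", torch.float32
    k = torch.randn(batch, heads, seqlen, dhead, device=dev, dtype=dt)
    w = torch.randn(batch, heads, seqlen, dhead, device=dev, dtype=dt)
    u = torch.randn(batch, heads, seqlen, dhead * expand_v, device=dev, dtype=dt)
    g = torch.randn(batch, heads, seqlen, device=dev, dtype=dt)
    return k, w, u, g

def test_gdn_fwd_h():
    torch.manual_seed(0)
    k, w, u, g = make_inputs()
    # reference_gdn_fwd_h is a placeholder for the PR's eager reference.
    assert_close(
        helion_gdn_fwd_h(k, w, u, g, chunk_size=64),
        reference_gdn_fwd_h(k, w, u, g, chunk_size=64),
        rtol=1e-2, atol=1e-2,
    )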

block_v = hl.register_block_size(dstate)

for tile_b, tile_h, tile_v in hl.tile(
    [batch, nheads, dstate], block_size=[1, 1, block_v]

Out of curiosity, why is it specialized on 1? Also, why do all of the rest use tile_b.begin instead of just tile_b? Feels a bit ugly 🤔

@v0i0 (Contributor, Author) commented Nov 21, 2025

This seems (to me) like a good way to get a grid over the purely batched portions of the problem (by setting block size 1).
The advantage of using begin is that the resulting tensors are indexed rather than sliced, i.e. a[t.begin, :] is 1-D but a[t, :] is a 2-D tensor (with a leading dim of size 1). Empirically, Triton is a lot happier with low-dimensional tensors, and it saves typing. You basically get vmap-like syntax where you don't need to worry about batched indices once they're indexed out.
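The same indexed-versus-sliced distinction in plain PyTorch, for illustration:

import torch

a = torch.randn(4, 8)
print(a[2, :].shape)    # torch.Size([8])    -- integer index removes the dim
print(a[2:3, :].shape)  # torch.Size([1, 8]) -- slice keeps a size-1 dim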


hmm... why not just use hl.grid then?

I understand what you're saying, but I think the resulting code reads a bit ugly/hacky. For example, I would prefer syntax perhaps like

for idx_b, idx_h, tile_v in hl.tile([batch, nheads, dstate], block_size=[0, 0, block_v])

I think the littering of tile_b.begin is semantically confusing.

Also, in terms of lower-dim tensors, I would prefer to just autotune over that.

Contributor Author

@yf225 I think you had ideas about not needing x.begin, right?
