feat: Implement Sink Attention #1819
Conversation
My current implementation approach is to add a new top-level interface `flash_attn_sink_func`, which results in a lot of redundant code. Would you consider adding the sink parameter to the existing interfaces instead?
It's better to add it to the existing interfaces instead of duplicating code.
Hi @tridao, I have updated the PR to address your feedback. Please take another look when you have time. Thanks again for your guidance!
I plan to add sink to the hopper version in the next PR. |
(force-pushed from e34d3ad to c00f806)
It seems
Hi @aoxy, will this work also be ported to FlashAttention-3 (the code in the hopper/ subdir)?
@guilhermeleobas, yes, I also plan to port this work to FlashAttention-3.
Fix attention with sink combine_attn_seqk_parallel.
Hi @tridao, could you please help review this PR? All tests pass. Would appreciate your feedback. Thank you!
Please
Hi @tridao, sorry to disturb you, but may I ask if there are any updates on the review of this PR? Also, I wonder whether you are still considering integrating Sink Attention into Flash Attention v2. Thank you very much!
@aoxy hi, any update on the merge work? We’ve recently run into low training efficiency when doing RL training with gpt-oss because sink attention support is inconsistent between training and inference, and we’d like to know when FA2/FA3 are expected to officially support sink attention. For more information, you can have a look at volcengine/verl#3794 |
I don't know; you could consider using an internal version.
Fixes #1802
Description
This pull request introduces support for the Sink Attention mechanism (directly following the implementation in GPT-OSS).
This is implemented by optionally passing a new `learnable_sink` parameter to the attention functions. This parameter is a tensor of shape `(nheads,)` that provides a learnable bias for each attention head. The sink value is added as an extra logit in the attention score calculation before the softmax, allowing the model to learn to "sink" a portion of the attention mass into a global entry, enhancing model capacity when handling long sequences.
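For reference, here is a minimal PyTorch sketch of the sink softmax described above. It is an illustrative, non-fused reimplementation of the math, not the actual CUDA kernel; the function name `attention_with_sink` and the tensor layout are assumptions, while `learnable_sink` matches the parameter introduced by this PR.

```python
import torch

def attention_with_sink(q, k, v, learnable_sink):
    """Reference (non-fused) attention with a per-head learnable sink logit.

    q, k, v: (batch, nheads, seqlen, headdim); learnable_sink: (nheads,)
    """
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale
    # Append the sink as one extra "virtual" key column per head.
    sink = learnable_sink.view(1, -1, 1, 1).expand(q.shape[0], -1, q.shape[2], 1)
    probs = torch.softmax(torch.cat([scores, sink], dim=-1), dim=-1)
    # The sink column absorbs attention mass but contributes no value,
    # so it is dropped before the weighted sum over V.
    return torch.einsum("bhqk,bhkd->bhqd", probs[..., :-1], v)
```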
Key Changes:
- Python interface (`flash_attn_interface.py`): Added the `learnable_sink` parameter to all major forward functions (`flash_attn_func`, `flash_attn_qkvpacked_func`, etc.), including their variable-length counterparts. The backward pass is updated to compute and return the gradient for `learnable_sink`.
- CUDA kernels (`csrc/`): The C++ API and CUDA kernels now accept the sink values, incorporating them into the softmax computation during the forward pass and computing their gradients during the backward pass. A new template parameter `Has_sink` is introduced to conditionally compile the sink-related logic.
- Tests (`tests/test_flash_attn.py`): Comprehensive tests, including new cases with `has_learnable_sink=True`, validate both the forward and backward passes. Numerical correctness of outputs and gradients is verified against reference implementations. All tests pass.
- Benchmarks (`benchmarks/benchmark_flash_attention.py`): The benchmark script adds a "Flash2Sink" method to measure the feature's performance impact.
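Assuming the interface lands as described above, usage would look roughly like the following sketch. The `learnable_sink` keyword argument is the parameter this PR adds; the exact signature should be confirmed against the diff, and the shapes and dtypes here are only illustrative.

```python
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 4096, 16, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda",
                dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# One learnable sink logit per head, trained jointly with the rest of the model.
learnable_sink = torch.nn.Parameter(
    torch.zeros(nheads, device="cuda", dtype=torch.bfloat16)
)

# learnable_sink is the new argument introduced by this PR.
out = flash_attn_func(q, k, v, causal=True, learnable_sink=learnable_sink)
out.sum().backward()  # per the PR description, learnable_sink.grad is populated too
```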