
Conversation

aoxy commented Aug 18, 2025

Fixes #1802 (Any plans to backport additive attention sinks to flash-attn-2?)

Description

This pull request introduces support for the Sink Attention mechanism (directly following the implementation in GPT-OSS).

This is implemented by optionally incorporating a new learnable_sink parameter into the attention functions. This parameter is a tensor of shape (nheads,), providing a learnable bias for each attention head. The sink value is added as an extra logit in the attention score calculation before softmax, allowing the model to learn to "sink" a portion of the attention to a global entry, enhancing model capacity when handling long sequences.
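
For reference, here is a minimal, non-fused PyTorch sketch of this mechanism. Everything except the learnable_sink name and the extra-logit idea is an illustrative assumption (in particular the (batch, seqlen, nheads, headdim) layout); it is not the kernel code in this PR.

```python
import torch
import torch.nn.functional as F

def attention_with_sink_ref(q, k, v, learnable_sink, softmax_scale=None):
    # q, k, v: (batch, seqlen, nheads, headdim); learnable_sink: (nheads,)
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5
    # Standard attention scores.
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * softmax_scale
    # Append one per-head sink logit as an extra "global" column.
    sink = learnable_sink.view(1, -1, 1, 1).expand(
        scores.shape[0], -1, scores.shape[2], 1
    )
    probs = F.softmax(torch.cat([scores, sink], dim=-1), dim=-1)
    # The sink column absorbs probability mass but contributes no value,
    # so it is dropped before multiplying with V.
    return torch.einsum("bhqk,bkhd->bqhd", probs[..., :-1], v)
```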

Key Changes:

  • Python Interface (flash_attn_interface.py):
    Added the learnable_sink parameter to all major forward functions (flash_attn_func, flash_attn_qkvpacked_func, etc.), including their variable-length counterparts. The backward pass is updated to compute and return the gradient for learnable_sink (a usage sketch follows this list).
  • C++/CUDA Kernels (csrc/):
    The C++ API and CUDA kernels now accept and integrate the sink values, incorporating them into the softmax computation during the forward pass, and calculating their gradients during the backward pass. A new template parameter Has_sink is introduced to conditionally compile the sink-related logic.
  • Testing (tests/test_flash_attn.py):
    Comprehensive tests, including new test cases with has_learnable_sink=True, have been added to validate both forward and backward passes. The numerical correctness of outputs and gradients is verified against reference implementations. All tests pass.
  • Benchmarks (benchmarks/benchmark_flash_attention.py):
    The benchmark script is updated to add a "Flash2Sink" method for measuring the feature's performance impact.
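
A hypothetical usage sketch of the updated Python interface: the learnable_sink keyword comes from this PR and its (nheads,) shape matches the description above, while the dtypes, sizes, and other arguments shown are assumptions, not the final API.

```python
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda",
                dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)
# One learnable sink logit per attention head.
learnable_sink = torch.zeros(nheads, device="cuda", dtype=torch.bfloat16,
                             requires_grad=True)

out = flash_attn_func(q, k, v, causal=True, learnable_sink=learnable_sink)
out.sum().backward()  # the backward pass also populates learnable_sink.grad
```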

aoxy (Author) commented Aug 18, 2025

My current implementation approach is to add a new top-level interface flash_attn_sink_func, which results in a lot of redundant code. Would you consider adding the sink parameter to the existing interfaces instead?

tridao (Member) commented Aug 18, 2025

It's better to add it to the existing interface instead of duplicating code.

aoxy (Author) commented Aug 22, 2025

Hi @tridao, I have updated the PR to address your feedback. Please take another look when you have time. Thanks again for your guidance!

aoxy (Author) commented Aug 22, 2025

I plan to add sink support to the Hopper version in a follow-up PR.

aoxy force-pushed the feature/attention_with_sink branch from e34d3ad to c00f806 on August 22, 2025 03:50
gunjunlee commented Aug 23, 2025

It seems flash_attn_with_kvcache with sink produces incorrect results during decoding. Could you check it?

guilhermeleobas commented:

Hi @aoxy, will this work also be ported to FlashAttention-3 (the code in the hopper/ subdir)?

aoxy (Author) commented Sep 10, 2025

@guilhermeleobas, yes, I also plan to port this work to FlashAttention-3.

aoxy (Author) commented Sep 19, 2025

Hi @tridao,

Could you please help review this PR?

All tests pass. Would appreciate your feedback. Thank you!

Potatooff commented:

Please

aoxy (Author) commented Oct 21, 2025

Hi @tridao,

Sorry to disturb you, but are there any updates on the review of this PR?

Also, are you still considering integrating Sink Attention into FlashAttention-2?

Thank you very much!

liuqianchao commented:

Hi @aoxy, any update on the merge work?

We’ve recently run into low training efficiency when doing RL training with gpt-oss because sink attention support is inconsistent between training and inference, and we’d like to know when FA2/FA3 are expected to officially support sink attention.

For more information, you can have a look at volcengine/verl#3794

aoxy (Author) commented Dec 2, 2025


I don't know; you could consider using an internal version.
