
HKUSTDial/flash-sparse-attention



Flash-Sparse-Attention is a high-performance, trainable sparse attention implementation. It combines FlashAttention's memory efficiency with sparse computation to handle extremely long sequences in Transformer models.

Key Features

Note

Support for arbitrary mask and bias shapes is available in this branch. The current main branch no longer maintains that feature set.

Supported Features

  • Forward and backward passes for dense attention, sparse attention, and gated attention
  • Regular batched inputs and varlen inputs
  • Causal attention and local window attention
  • Arbitrary combinations of Q and KV sequence lengths, with head dimensions up to 256
  • Grouped Query Attention and Multi Query Attention
  • Sparse softmax threshold control
  • Gated attention with gate inputs and configurable gating sparsity
  • Flex Local Window Attention with per-head arbitrary window sizes and local ranges
  • Split-KV for load balancing in forward and decode workloads
  • Split-QO for load balancing in backward workloads
  • Fused Quant for low-precision computation on hardware without native FP8 support
  • Top-k gather KV-cache decode
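The causal and local-window options can be summarized as a per-(query, key) visibility rule. The sketch below is a plain-Python reference predicate for building intuition and testing, not the kernel's code. Two things are assumptions: when seqlen_q differs from seqlen_k, rows are aligned to the bottom-right corner (the convention FlashAttention uses), and the window is a hypothetical (left, right) pair of token counts rather than this library's actual parameter.

```python
def is_visible(i, j, seqlen_q, seqlen_k, causal=True, window=None):
    """Return True if query row i may attend to key column j.

    Assumption: bottom-right alignment, so the last query row lines up
    with the last key column when seqlen_q != seqlen_k.
    `window` is a hypothetical (left, right) pair of token counts;
    None means no local-window restriction.
    """
    # Key position that query i is aligned with.
    diag = i + (seqlen_k - seqlen_q)
    if causal and j > diag:
        return False  # future key under causal masking
    if window is not None:
        left, right = window
        if j < diag - left or j > diag + right:
            return False  # outside the local window
    return True


def build_mask(seqlen_q, seqlen_k, causal=True, window=None):
    """Materialize the full boolean mask as a list of rows (for inspection only)."""
    return [
        [is_visible(i, j, seqlen_q, seqlen_k, causal, window) for j in range(seqlen_k)]
        for i in range(seqlen_q)
    ]
```

For example, `build_mask(4, 6, window=(2, 0))` yields a causal band three keys wide whose last row ends at the last key column.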

Features We Aim to Support

  • Paged Attention
  • KV-Cache Manager
  • TLE backend support
  • Gluon backend support

Installation

Requirements

  • Linux: Ubuntu 22.04 or later
  • Device: GPU, XPU, NPU, or PPU
  • Python: 3.9 or later
  • PyTorch: 2.5.1 or later
  • Triton: 3.6.0 or later
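A quick way to check an environment against these minimums is to read the installed package versions with the standard library. This is a sketch, not a script shipped with the project; it assumes the distributions are named `torch` and `triton`, which may differ for some non-NVIDIA backends.

```python
from importlib.metadata import PackageNotFoundError, version
import re
import sys

# Minimum versions copied from the Requirements list.
MINIMUMS = {"torch": (2, 5, 1), "triton": (3, 6, 0)}


def parse_version(text):
    """Turn '2.5.1+cu121' into (2, 5, 1); local tags and pre-release suffixes are ignored."""
    parts = []
    for piece in text.split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    # Pad so that '3.6' compares equal to (3, 6, 0).
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts[:3])


def check_requirements():
    """Return a list of problems; an empty list means every check passed."""
    problems = []
    if sys.version_info < (3, 9):
        problems.append(f"Python {sys.version.split()[0]} is older than 3.9")
    for name, minimum in MINIMUMS.items():
        try:
            installed = parse_version(version(name))
        except PackageNotFoundError:
            problems.append(f"{name} is not installed")
            continue
        if installed < minimum:
            problems.append(f"{name} {installed} is older than {minimum}")
    return problems
```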

Install

Install from PyPI:

pip install flash-sparse-attn

To install from source:

git clone https://github.com/flash-algo/flash-sparse-attn.git
cd flash-sparse-attn
pip install .

Install via HuggingFace Kernel

You can also load the kernels directly from HuggingFace Kernel without installing the package:

from kernels import get_kernel

fsa = get_kernel("JingzeShi/flash-sparse-attn", version=1, trust_remote_code=True)

# q, k, v, alpha, and delta are tensors shaped as in the Quick Start below
out = fsa.flash_dense_attn_func(q, k, v, is_causal=True)
out = fsa.flash_sparse_attn_func(q, k, v, is_causal=True, softmax_threshold=0.01)
out = fsa.flash_gated_attn_func(q, k, v, alpha, delta, is_causal=True)

This requires the kernels package: pip install kernels

Quick Start

Basic Usage

The examples below cover the three attention variants. They all share this setup:

import torch
from flash_sparse_attn.ops.triton.interface import (
    flash_dense_attn_func,
    flash_sparse_attn_func,
    flash_gated_attn_func,
)

dtype = torch.bfloat16
device = torch.device("cuda")
batch_size, seqlen_q, seqlen_k, num_heads, num_kv_heads, head_dim = 2, 1024, 1024, 8, 2, 64

query = torch.randn(batch_size, seqlen_q, num_heads, head_dim, dtype=dtype, device=device)
key = torch.randn(batch_size, seqlen_k, num_kv_heads, head_dim, dtype=dtype, device=device)
value = torch.randn(batch_size, seqlen_k, num_kv_heads, head_dim, dtype=dtype, device=device)
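With num_heads=8 and num_kv_heads=2, this setup exercises Grouped Query Attention: each KV head is shared by num_heads // num_kv_heads = 4 query heads. The sketch below shows the contiguous grouping documented by FlashAttention (query heads 0 to 3 read KV head 0, heads 4 to 7 read KV head 1). That this library groups heads the same way is an assumption, and the helper names are hypothetical.

```python
def kv_head_for_query_head(h, num_heads, num_kv_heads):
    """Map query head h to the KV head it reads under contiguous grouping."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    group_size = num_heads // num_kv_heads
    return h // group_size


def expand_kv_heads(kv_heads, num_heads):
    """Repeat each KV head so there is one per query head, as a naive
    dense reference would do before running ordinary multi-head attention."""
    num_kv_heads = len(kv_heads)
    return [kv_heads[kv_head_for_query_head(h, num_heads, num_kv_heads)] for h in range(num_heads)]
```

Multi Query Attention is the special case num_kv_heads=1, where every query head maps to KV head 0.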

Dense Attention

Use this when you do not need explicit sparsification but still want an efficient attention kernel.

output_dense = flash_dense_attn_func(
    query=query,
    key=key,
    value=value,
    is_causal=True,
)

print(output_dense.shape)
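When validating flash_dense_attn_func on tiny inputs, a naive reference is handy. The sketch below computes softmax(scale · QKᵀ)V for a single head in plain Python. The default scale of 1/sqrt(head_dim) and the bottom-right causal alignment are assumptions based on the usual conventions, not confirmed details of this library.

```python
import math


def reference_attention(q, k, v, is_causal=True, scale=None):
    """Naive single-head attention. q: [seqlen_q][d], k and v: [seqlen_k][d].

    If seqlen_q > seqlen_k with is_causal=True, the earliest rows are fully
    masked and produce NaN here.
    """
    seqlen_q, seqlen_k, head_dim = len(q), len(k), len(q[0])
    scale = 1.0 / math.sqrt(head_dim) if scale is None else scale
    offset = seqlen_k - seqlen_q  # bottom-right causal alignment (assumption)
    out = []
    for i in range(seqlen_q):
        scores = []
        for j in range(seqlen_k):
            if is_causal and j > i + offset:
                scores.append(float("-inf"))  # masked future position
            else:
                scores.append(scale * sum(a * b for a, b in zip(q[i], k[j])))
        m = max(scores)  # subtract the row max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j][c] for j, w in enumerate(weights)) / total for c in range(len(v[0]))])
    return out
```

To compare against the kernel, slice one batch element and one head out of the tensors, convert them with `.float().tolist()`, and allow a tolerance that reflects bfloat16 precision.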

Sparse Attention

Use this when you want to skip low-contribution attention weights via softmax_threshold, reducing the effective compute on long sequences.

output_sparse = flash_sparse_attn_func(
    query=query,
    key=key,
    value=value,
    is_causal=True,
    softmax_threshold=1.0,
)

print(output_sparse.shape)
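To build intuition for why skipping low-contribution weights changes the output only slightly, the sketch below applies a hypothetical rule to one row: probabilities below a cutoff are dropped and the survivors renormalized. This illustrates the general idea only. The exact criterion and scale behind the kernel's softmax_threshold are not described here (the examples pass 1.0, which suggests it is not a direct cutoff on normalized probabilities), and `prune_and_renormalize` and its `cutoff` parameter are hypothetical.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def prune_and_renormalize(probs, cutoff):
    """Hypothetical sparsification: zero out probabilities below `cutoff`,
    then renormalize so the row still sums to 1."""
    kept = [p if p >= cutoff else 0.0 for p in probs]
    total = sum(kept)
    if total == 0.0:
        raise ValueError("cutoff removed every weight in the row")
    return [p / total for p in kept]


def weighted_sum(probs, values):
    """Combine scalar values with attention weights."""
    return sum(p * x for p, x in zip(probs, values))
```

For scores [4.0, 3.5, 0.0, -1.0, -2.0] and a cutoff of 0.05, only the first two keys survive, and the weighted sum of values [1, 2, 3, 4, 5] moves from about 1.41 to about 1.38.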

Gated Attention

Use this when you need explicit gating signals for sparse attention. alpha controls query-side gating and has shape (batch_size, num_heads, seqlen_q); delta controls key-side gating and has shape (batch_size, num_kv_heads, seqlen_k).

alpha = torch.randn(batch_size, num_heads, seqlen_q, device=device, dtype=dtype)
delta = torch.randn(batch_size, num_kv_heads, seqlen_k, device=device, dtype=dtype)

output_gated = flash_gated_attn_func(
    query=query,
    key=key,
    value=value,
    alpha=alpha,
    delta=delta,
    is_causal=True,
    softmax_threshold=1.0,
    gate_threshold=1.0,
)

print(output_gated.shape)
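Shape mismatches between the gate tensors and the attention inputs are an easy mistake. The helper below is a hypothetical pre-flight check (not part of the library) that mirrors the layouts used in the Quick Start. The requirements that value match key exactly and that num_heads be divisible by num_kv_heads are assumptions; the head-dimension limit comes from the feature list.

```python
def check_gated_shapes(q_shape, k_shape, v_shape, alpha_shape, delta_shape):
    """Validate shapes for the gated-attention call.

    Layouts assumed from the Quick Start example:
      query:      (batch, seqlen_q, num_heads, head_dim)
      key, value: (batch, seqlen_k, num_kv_heads, head_dim)
      alpha:      (batch, num_heads, seqlen_q)
      delta:      (batch, num_kv_heads, seqlen_k)
    Returns (num_heads, num_kv_heads) or raises ValueError.
    """
    batch, seqlen_q, num_heads, head_dim = q_shape
    batch_k, seqlen_k, num_kv_heads, head_dim_k = k_shape
    if batch_k != batch or head_dim_k != head_dim:
        raise ValueError(f"key shape {tuple(k_shape)} is incompatible with query shape {tuple(q_shape)}")
    if tuple(v_shape) != tuple(k_shape):
        raise ValueError(f"value shape {tuple(v_shape)} must match key shape {tuple(k_shape)}")
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    if head_dim > 256:
        raise ValueError("head dimensions above 256 are not supported")
    if tuple(alpha_shape) != (batch, num_heads, seqlen_q):
        raise ValueError(f"alpha must have shape {(batch, num_heads, seqlen_q)}, got {tuple(alpha_shape)}")
    if tuple(delta_shape) != (batch, num_kv_heads, seqlen_k):
        raise ValueError(f"delta must have shape {(batch, num_kv_heads, seqlen_k)}, got {tuple(delta_shape)}")
    return num_heads, num_kv_heads
```

Because `torch.Size` is a tuple subclass, the tensors' `.shape` attributes can be passed directly, for example `check_gated_shapes(query.shape, key.shape, value.shape, alpha.shape, delta.shape)`.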

Performance

The following benchmarks cover forward, backward, and decode workloads. They include dense, sparse, and gated implementations, with FlashAttention used as the baseline.
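To turn your own latency measurements into throughput, a common convention is the one in FlashAttention's benchmark scripts: count two matmuls (QKᵀ and PV) at 2 FLOPs per multiply-add, halve for causal masking, and use 2.5 times the forward count for the backward pass. Whether the charts in this section use this convention is not stated. Note also that it counts dense work, so for sparse or gated kernels that skip blocks the result is a dense-equivalent figure. With GQA, pass the number of query heads.

```python
def attention_flops(batch, num_heads, seqlen_q, seqlen_k, head_dim, causal=False, mode="fwd"):
    """Dense-equivalent attention FLOPs, following FlashAttention's benchmark convention."""
    # QK^T and PV each cost 2 * seqlen_q * seqlen_k * head_dim FLOPs per head.
    flops = 4 * batch * num_heads * seqlen_q * seqlen_k * head_dim
    if causal:
        flops //= 2  # roughly half of the score matrix is masked out
    if mode == "fwd":
        return flops
    if mode == "bwd":
        return flops * 5 // 2  # backward counted as 2.5x forward
    raise ValueError("mode must be 'fwd' or 'bwd'")


def tflops_per_second(flops, seconds):
    """Convert a FLOP count and a runtime in seconds to TFLOPs/s."""
    return flops / seconds / 1e12
```

With the Quick Start shapes (batch 2, 8 query heads, 1024 tokens, head dim 64), a causal forward pass counts 2**31 FLOPs, so a 1 ms runtime corresponds to about 2.15 TFLOPs/s.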

NVIDIA GPU

A100

Forward Performance

[Figure: Attention forward speed, head dim 128, A100]

Backward Performance

[Figure: Attention backward speed, head dim 128, A100]

Decode Performance

[Figure: Attention decode speed, head dim 128, A100]

H20

Forward Performance

[Figure: Attention forward speed, head dim 128, H20-3e]

Backward Performance

[Figure: Attention backward speed, head dim 128, H20-3e]

Decode Performance

[Figure: Attention decode speed, head dim 128, H20-3e]

RTX PRO 6000

Forward Performance

[Figure: Attention forward speed, head dim 128, RTX PRO 6000]

Backward Performance

[Figure: Attention backward speed, head dim 128, RTX PRO 6000]

Decode Performance

[Figure: Attention decode speed, head dim 128, RTX PRO 6000]

T-Head PPU

ZW810E

Forward Performance

[Figure: Attention forward speed, head dim 128, ZW810E]

Backward Performance

[Figure: Attention backward speed, head dim 128, ZW810E]

Decode Performance

[Figure: Attention decode speed, head dim 128, ZW810E]

Benchmarking

Benchmark scripts are located in the tests directory and cover forward, backward, and decode performance.

Forward Performance

python tests/benchmark_forward.py

Backward Performance

python tests/benchmark_backward.py

Decode Performance

python tests/benchmark_decode.py
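The scripts above are the reference way to reproduce these benchmarks. For a quick ad-hoc measurement of a single call, a minimal harness like the following can help; it is a sketch, not part of the repository. GPU kernels launch asynchronously, so pass a synchronization callback (for PyTorch on CUDA, `torch.cuda.synchronize`), otherwise the timings mostly reflect launch overhead.

```python
import statistics
import time


def benchmark(fn, warmup=5, iters=20, sync=None):
    """Return the median wall-clock time of fn() in seconds.

    `sync` is an optional zero-argument callable that blocks until queued
    device work finishes; it runs before the clock starts and before it
    stops, so asynchronous kernels are fully counted.
    """
    for _ in range(warmup):  # absorb one-time costs such as Triton JIT compilation
        fn()
    if sync is not None:
        sync()
    times = []
    for _ in range(iters):
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        times.append(time.perf_counter() - start)
    return statistics.median(times)
```

For example, with the Quick Start tensors: `benchmark(lambda: flash_dense_attn_func(query=query, key=key, value=value, is_causal=True), sync=torch.cuda.synchronize)`. The median is used instead of the mean so that occasional slow iterations do not skew the result.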

Citation

If you use Flash-Sparse-Attention (FSA) in your research, please cite:

@misc{shi2025trainabledynamicmasksparse,
      title={Trainable Dynamic Mask Sparse Attention},
      author={Jingze Shi and Yifan Wu and Bingheng Wu and Yiran Peng and Liangdong Wang and Guang Liu and Yuyu Luo},
      year={2025},
      eprint={2508.02124},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2508.02124},
}

Acknowledgments

This project builds upon and integrates several excellent works:

We thank the open-source community for its contributions to efficient Transformer implementations. 🤗