-
Notifications
You must be signed in to change notification settings - Fork 1.5k
[Metal] Add Flash Attention VJP for training #2995
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
Notes: I'm working on https://github.com/mlx-node/mlx-node and trying to port some features in trl. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.
mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dkv.h
Outdated
Show resolved
Hide resolved
35a7886 to
568ff36
Compare
Implement fused backward pass (VJP) for scaled_dot_product_attention on Metal GPU, enabling efficient training without falling back to unfused attention. - **dQ Kernel** (steel_attention_vjp_dq.h): Computes query gradients - Outer loop over KV blocks, inner accumulation for dQ - Uses log2 domain for numerical stability - **dK/dV Kernel** (steel_attention_vjp_dkv.h): Computes key/value gradients - K-row ownership model eliminates atomic operations - Each simdgroup owns exclusive K rows to prevent races - Optimized path for short sequences (L ≤ 8) - Uses shared memory for efficient reduction - Float32 accumulators for half/bfloat16 precision - Logsumexp caching from forward pass - Proper GQA (grouped query attention) support - Causal mask support - Comprehensive test coverage for all code paths - No gradient support for mask or attention sinks (falls back to unfused) - Requires logsumexp from forward pass (training mode only) - Head dimension D=256 not supported in vector VJP (threadgroup memory) Co-Authored-By: Claude <[email protected]>
568ff36 to
26b5857
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.
| if (use_vector_vjp) { | ||
| if (needs_float32_accumulators) { | ||
| // Use float32 accumulator buffers with the accumulate kernel variant | ||
| // This variant has device float* signature for dK/dV, ensuring correct | ||
| // pointer arithmetic (sizeof(float)=4) instead of sizeof(T)=2 for | ||
| // half/bfloat16 | ||
| array& dk_acc = dk_accum.value(); | ||
| array& dv_acc = dv_accum.value(); | ||
| sdpa_vector_vjp_accumulate_dispatch( | ||
| s, | ||
| d, | ||
| q, | ||
| k, | ||
| v, | ||
| o, | ||
| dO, | ||
| lse, | ||
| d_q, | ||
| dk_acc, | ||
| dv_acc, | ||
| scale_, | ||
| do_causal, | ||
| mask, | ||
| sinks); | ||
|
|
||
| // Convert float32 accumulators to original dtype | ||
| // This uses the standard copy primitive with type conversion | ||
| copy_gpu(dk_acc, d_k, CopyType::General, s); | ||
| copy_gpu(dv_acc, d_v, CopyType::General, s); | ||
|
|
||
| // Add accumulators as temporaries for cleanup | ||
| d.add_temporary(dk_acc, s.index); | ||
| d.add_temporary(dv_acc, s.index); | ||
| } else { | ||
| // Float32: pass dK/dV directly (already zero-initialized above) | ||
| sdpa_vector_vjp_dispatch( | ||
| s, | ||
| d, | ||
| q, | ||
| k, | ||
| v, | ||
| o, | ||
| dO, | ||
| lse, | ||
| d_q, | ||
| d_k, | ||
| d_v, | ||
| scale_, | ||
| do_causal, | ||
| mask, | ||
| sinks); | ||
| } | ||
| } else { | ||
| // Two-kernel STEEL VJP approach for longer sequences | ||
| // This eliminates atomic operations by using separate kernels for dQ and | ||
| // dK/dV: | ||
| // - dQ kernel: loops over KV blocks (grid [NQ, H, B]) | ||
| // - dKV kernel: loops over Q blocks (grid [NK, H, B]) | ||
| // Each kernel owns its output entirely, avoiding race conditions. | ||
|
|
||
| // STEEL VJP kernels (do not support masks/sinks) | ||
| // Safety assertion: use_fallback() should prevent reaching here with | ||
| // masks/sinks | ||
| assert( | ||
| (!has_arr_mask && !has_sinks_) && | ||
| "STEEL VJP called with mask/sinks - use_fallback() should have prevented this"); | ||
|
|
||
| // Runtime guard (active even in release builds where assert is disabled) | ||
| if (has_arr_mask || has_sinks_) { | ||
| throw std::runtime_error( | ||
| "Internal error: STEEL VJP called with masks/sinks. " | ||
| "This indicates a bug in use_fallback() logic."); | ||
| } | ||
|
|
||
| // Dispatch dQ kernel - computes only dQ gradients | ||
| sdpa_steel_vjp_dq_dispatch( | ||
| s, d, q, k, v, o, dO, lse, d_q, scale_, do_causal); | ||
|
|
||
| // Dispatch dKV kernel - computes dK and dV gradients directly | ||
| sdpa_steel_vjp_dkv_dispatch( | ||
| s, d, q, k, v, o, dO, lse, d_k, d_v, scale_, do_causal); | ||
| } | ||
|
|
Copilot
AI
Jan 14, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The variable use_vector_vjp is hardcoded to false at line 1468, which effectively disables the vector VJP path entirely. The extensive code for vector VJP (lines 1485-1587) becomes dead code. Either remove the dead code or use a proper condition to enable vector VJP when appropriate. The comment explains why it's disabled (logsumexp domain mismatch), but having 100+ lines of unreachable code is a maintainability issue.
| if (use_vector_vjp) { | |
| if (needs_float32_accumulators) { | |
| // Use float32 accumulator buffers with the accumulate kernel variant | |
| // This variant has device float* signature for dK/dV, ensuring correct | |
| // pointer arithmetic (sizeof(float)=4) instead of sizeof(T)=2 for | |
| // half/bfloat16 | |
| array& dk_acc = dk_accum.value(); | |
| array& dv_acc = dv_accum.value(); | |
| sdpa_vector_vjp_accumulate_dispatch( | |
| s, | |
| d, | |
| q, | |
| k, | |
| v, | |
| o, | |
| dO, | |
| lse, | |
| d_q, | |
| dk_acc, | |
| dv_acc, | |
| scale_, | |
| do_causal, | |
| mask, | |
| sinks); | |
| // Convert float32 accumulators to original dtype | |
| // This uses the standard copy primitive with type conversion | |
| copy_gpu(dk_acc, d_k, CopyType::General, s); | |
| copy_gpu(dv_acc, d_v, CopyType::General, s); | |
| // Add accumulators as temporaries for cleanup | |
| d.add_temporary(dk_acc, s.index); | |
| d.add_temporary(dv_acc, s.index); | |
| } else { | |
| // Float32: pass dK/dV directly (already zero-initialized above) | |
| sdpa_vector_vjp_dispatch( | |
| s, | |
| d, | |
| q, | |
| k, | |
| v, | |
| o, | |
| dO, | |
| lse, | |
| d_q, | |
| d_k, | |
| d_v, | |
| scale_, | |
| do_causal, | |
| mask, | |
| sinks); | |
| } | |
| } else { | |
| // Two-kernel STEEL VJP approach for longer sequences | |
| // This eliminates atomic operations by using separate kernels for dQ and | |
| // dK/dV: | |
| // - dQ kernel: loops over KV blocks (grid [NQ, H, B]) | |
| // - dKV kernel: loops over Q blocks (grid [NK, H, B]) | |
| // Each kernel owns its output entirely, avoiding race conditions. | |
| // STEEL VJP kernels (do not support masks/sinks) | |
| // Safety assertion: use_fallback() should prevent reaching here with | |
| // masks/sinks | |
| assert( | |
| (!has_arr_mask && !has_sinks_) && | |
| "STEEL VJP called with mask/sinks - use_fallback() should have prevented this"); | |
| // Runtime guard (active even in release builds where assert is disabled) | |
| if (has_arr_mask || has_sinks_) { | |
| throw std::runtime_error( | |
| "Internal error: STEEL VJP called with masks/sinks. " | |
| "This indicates a bug in use_fallback() logic."); | |
| } | |
| // Dispatch dQ kernel - computes only dQ gradients | |
| sdpa_steel_vjp_dq_dispatch( | |
| s, d, q, k, v, o, dO, lse, d_q, scale_, do_causal); | |
| // Dispatch dKV kernel - computes dK and dV gradients directly | |
| sdpa_steel_vjp_dkv_dispatch( | |
| s, d, q, k, v, o, dO, lse, d_k, d_v, scale_, do_causal); | |
| } | |
| // Two-kernel STEEL VJP approach for longer sequences | |
| // This eliminates atomic operations by using separate kernels for dQ and | |
| // dK/dV: | |
| // - dQ kernel: loops over KV blocks (grid [NQ, H, B]) | |
| // - dKV kernel: loops over Q blocks (grid [NK, H, B]) | |
| // Each kernel owns its output entirely, avoiding race conditions. | |
| // STEEL VJP kernels (do not support masks/sinks) | |
| // Safety assertion: use_fallback() should prevent reaching here with | |
| // masks/sinks | |
| assert( | |
| (!has_arr_mask && !has_sinks_) && | |
| "STEEL VJP called with mask/sinks - use_fallback() should have prevented this"); | |
| // Runtime guard (active even in release builds where assert is disabled) | |
| if (has_arr_mask || has_sinks_) { | |
| throw std::runtime_error( | |
| "Internal error: STEEL VJP called with masks/sinks. " | |
| "This indicates a bug in use_fallback() logic."); | |
| } | |
| // Dispatch dQ kernel - computes only dQ gradients | |
| sdpa_steel_vjp_dq_dispatch( | |
| s, d, q, k, v, o, dO, lse, d_q, scale_, do_causal); | |
| // Dispatch dKV kernel - computes dK and dV gradients directly | |
| sdpa_steel_vjp_dkv_dispatch( | |
| s, d, q, k, v, o, dO, lse, d_k, d_v, scale_, do_causal); |
Summary
Implements fused backward pass (VJP) for
scaled_dot_product_attentionon Metal GPU. This enables efficient gradient computation during training without falling back to unfused (decomposed) attention operations.Changes
New Files
mlx/backend/metal/kernels/sdpa_vector_vjp.h- Vector VJP kernel for short sequences (L ≤ 8)mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dq.h- STEEL dQ gradient kernelmlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dkv.h- STEEL dK/dV gradient kernelModified Files
mlx/backend/metal/scaled_dot_product_attention.cpp- VJP dispatch logic (+840 lines)mlx/fast.cpp/mlx/fast_primitives.h- Logsumexp caching, VJP routingpython/tests/test_fast_sdpa.py- Comprehensive VJP tests (+220 lines)Implementation Notes
Uses a two-kernel approach to avoid atomic operations:
dQ kernel (
steel_attention_vjp_dq.h):dK/dV kernel (
steel_attention_vjp_dkv.h):Vector VJP (
sdpa_vector_vjp.h):Key Features
Limitations
Test Plan
test_sdpa_gradpassestest_sdpa_grad_vector_path- short sequences (L=1,4,7,8)test_sdpa_grad_steel_path- longer sequences (L=16,32,128,256)test_sdpa_grad_head_dims- head dimensions (D=32,64,96,128)test_sdpa_grad_gqa- GQA configurations (4:1, 8:1, 16:1, MHA)test_sdpa_grad_dtypes- float16, bfloat16, float32test_sdpa_grad_edge_cases- L=1, non-power-of-2, large batch, qL≠kvLAll 21 SDPA tests pass (1 skipped for unrelated disabled feature).