14 changes: 8 additions & 6 deletions include/sgl_kernel_ops.h
@@ -105,12 +105,14 @@ void merge_state(
void merge_state_v2(
    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
void cutlass_mla_decode(
    torch::Tensor const& out,
    torch::Tensor const& q_nope_and_q_pe,
    torch::Tensor const& kv_c_and_k_pe_cache,
    torch::Tensor const& seq_lens,
    torch::Tensor const& page_table,
    torch::Tensor const& workspace,
    torch::Tensor& out,
    const torch::Tensor& q_nope,
    const torch::Tensor& q_pe,
    const torch::Tensor& kv_c_and_k_pe_cache,
    const torch::Tensor& seq_lens,
    const torch::Tensor& page_table,
    torch::Tensor& workspace,
    double sm_scale,
    int64_t num_kv_splits = -1);
int64_t cutlass_mla_get_workspace_size(
    int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0, int64_t num_kv_splits = -1);
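Reviewer note: the new entry point threads sm_scale through explicitly and takes the input tensors by const reference, keeping only out and workspace mutable. A minimal sketch of sizing the workspace on the Python side, assuming cutlass_mla_get_workspace_size is re-exported from sgl_kernel (the import path and sizes are assumptions, not part of this diff):

import torch
from sgl_kernel import cutlass_mla_get_workspace_size  # assumed re-export

max_seq_len, num_batches = 4096, 2  # hypothetical sizes
# sm_count=0 and num_kv_splits=-1 fall back to the defaults declared above.
num_bytes = cutlass_mla_get_workspace_size(max_seq_len, num_batches)
workspace = torch.empty(num_bytes, device="cuda", dtype=torch.uint8)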
60 changes: 37 additions & 23 deletions python/sgl_kernel/attention.py
@@ -52,66 +52,80 @@ def merge_state_v2(


def cutlass_mla_decode(
    q_nope_and_q_pe: torch.Tensor,
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    seq_lens: torch.Tensor,
    page_table: torch.Tensor,
    workspace: torch.Tensor,
    num_kv_splits: int = -1,
    sm_scale: float,
    num_kv_splits: int = 1,
) -> torch.Tensor:
    assert (
        q_nope_and_q_pe.ndim == 3
    ), f"q_nope_and_q_pe must be a 3D tensor, but got {q_nope_and_q_pe.ndim}"
    assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
    assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
    assert (
        kv_c_and_k_pe_cache.ndim == 3
    ), f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}"
    B_q, H, D_q = q_nope_and_q_pe.shape

    B_q, H, D_q_nope = q_nope.shape
    B_q_2, H_2, D_q_pe = q_pe.shape
    assert (B_q == B_q_2) and (
        H == H_2
    ), f"q_nope and q_pe must agree on batch and heads, but got ({B_q}, {H}) vs ({B_q_2}, {H_2})"

    _, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape

    D_latent = 512
    # TODO: currently only D_latent=128 and D_rope=64 are supported (previously D_latent=512);
    # 128 is not the functionally correct latent dim for MLA, so once cutlass supports
    # head_dims > 256 this should be restored to 512.
    D_latent = 128  # 512
    D_rope = 64
    assert D_q == D_ckv and D_q == D_latent + D_rope, (
        f"D_q must be equal to D_ckv and D_q must be equal to D_latent + D_rope, "
        f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}"
    )
    # TODO: removed because only matching dim sizes for the kv cache and the output are
    # supported for now; re-enable once the mha_fwd kernel supports differing dim sizes.
    # assert D_q_nope == D_latent
    # assert D_q_pe == D_rope
    # assert D_ckv == D_latent + D_rope
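Reviewer note: for concreteness, the relationships the commented-out asserts would enforce once differing dim sizes land (illustrative arithmetic only, using the temporary values above):

D_latent, D_rope = 128, 64  # temporary; eventually 512 and 64
D_q_nope = D_latent         # latent (compressed-KV) part of the query
D_q_pe = D_rope             # rotary-positional part of the query
D_ckv = D_latent + D_rope   # cache packs kv_c and k_pe along the last dim
assert D_ckv == 192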

    MAX_HEADS = 128
    assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
    if H < MAX_HEADS:
        q_nope_and_q_pe_padded = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_q))
        q_nope_and_q_pe_padded[:, :H] = q_nope_and_q_pe
        q_nope_and_q_pe = q_nope_and_q_pe_padded
        q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
        q_nope_padded[:, :H] = q_nope
        q_nope = q_nope_padded

        q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
        q_pe_padded[:, :H] = q_pe
        q_pe = q_pe_padded

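Reviewer note: a self-contained sketch of the padding step above, with a hypothetical H=16, showing that rows [H:] stay uninitialized scratch and are dropped again by the final out[:, :H] slice:

import torch

MAX_HEADS, B_q, H, D_q_nope = 128, 2, 16, 128  # hypothetical shapes
q_nope = torch.randn(B_q, H, D_q_nope, dtype=torch.bfloat16)

# Pad the head dim to MAX_HEADS; rows [H:] are left uninitialized.
q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
q_nope_padded[:, :H] = q_nope
assert torch.equal(q_nope_padded[:, :H], q_nope)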
    assert len(page_table.shape) == 2
    B_block_table, block_num = page_table.shape
    assert B_block_table == B_q
    assert block_num > 0, f"block num must be greater than 0, got {block_num}"
    assert block_num % (128 / PAGE_SIZE) == 0

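Reviewer note: the last assert requires block_num to be a multiple of 128 / PAGE_SIZE, i.e. each page-table row must cover whole 128-token KV tiles. A quick illustration with a hypothetical PAGE_SIZE of 64:

import torch

PAGE_SIZE = 64                     # hypothetical page size
blocks_per_tile = 128 / PAGE_SIZE  # 2.0 pages per 128-token KV tile
block_num = 6                      # OK: 6 % 2.0 == 0
assert block_num % blocks_per_tile == 0
page_table = torch.zeros((2, block_num), dtype=torch.int32)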
    # TODO(kaixih@nvidia): support fp8
    assert q_nope_and_q_pe.dtype in (
    # TODO: support fp8
    assert q_nope.dtype in (
        torch.float16,
        torch.bfloat16,
    ), f"q_nope_and_q_pe.dtype needs to be fp16 or bf16 but got {q_nope_and_q_pe.dtype}."
    assert kv_c_and_k_pe_cache.dtype == q_nope_and_q_pe.dtype, (
        f"kv_c_and_k_pe_cache.dtype needs to be the same as q_nope_and_q_pe.dtype, "
        f"but got {kv_c_and_k_pe_cache.dtype}."
    )
    ), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}."
    assert (
        q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
    ), f"q_nope, q_pe and kv_c_and_k_pe_cache must share a dtype, but got {q_nope.dtype}, {q_pe.dtype} and {kv_c_and_k_pe_cache.dtype}."
    assert (
        seq_lens.dtype == torch.int32
    ), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
    assert (
        page_table.dtype == torch.int32
    ), f"page_table.dtype needs to be int32 but got {page_table.dtype}."

    out = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_latent))
    # TODO: the output's last dim is currently D_latent + D_rope for compatibility with the
    # cutlass mla decode kernel; switch to D_latent alone once the kernel supports differing dim sizes.
    out = q_nope.new_empty((B_q, MAX_HEADS, D_latent + D_rope))

    torch.ops.sgl_kernel.cutlass_mla_decode.default(
        out,
        q_nope_and_q_pe,
        q_nope,
        q_pe,
        kv_c_and_k_pe_cache,
        seq_lens,
        page_table,
        workspace,
        sm_scale,
        num_kv_splits,
    )
    return out[:, :H].contiguous()
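Reviewer note: end to end, a call site for the updated API would look roughly like the sketch below. All shapes, the import paths, and the sm_scale value are assumptions, and D_latent/D_rope follow the temporary 128/64 constraint above:

import torch
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size  # assumed re-exports

B, H, PAGE_SIZE, max_seq_len = 2, 16, 64, 1024
D_latent, D_rope = 128, 64
num_pages = B * (max_seq_len // PAGE_SIZE)  # 32 pages, 16 per sequence

q_nope = torch.randn(B, H, D_latent, dtype=torch.bfloat16, device="cuda")
q_pe = torch.randn(B, H, D_rope, dtype=torch.bfloat16, device="cuda")
kv_cache = torch.randn(num_pages, PAGE_SIZE, D_latent + D_rope, dtype=torch.bfloat16, device="cuda")
seq_lens = torch.full((B,), max_seq_len, dtype=torch.int32, device="cuda")
page_table = torch.arange(num_pages, dtype=torch.int32, device="cuda").view(B, -1)
workspace = torch.empty(cutlass_mla_get_workspace_size(max_seq_len, B), device="cuda", dtype=torch.uint8)

out = cutlass_mla_decode(
    q_nope, q_pe, kv_cache, seq_lens, page_table, workspace,
    sm_scale=(D_latent + D_rope) ** -0.5,  # hypothetical scale
)
assert out.shape == (B, H, D_latent + D_rope)  # last dim still includes D_rope for now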