
Commit 335b350

fix
1 parent 4c7aab3 commit 335b350

2 files changed: 4 additions, 1 deletion

python/sgl_jax/srt/layers/attention/flash_attn_kernel/flash_attention.py

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ def ref_ragged_paged_attention(
         if custom_mask != None:
             raise ValueError(f"use causal mask, custom_mask is not None")
     else:
-        if custom_mask == None or custom_mask.size() < jnp.cumsum(kv_lens)[-1]:
+        if custom_mask == None or custom_mask.size < jnp.cumsum(kv_lens)[-1]:
             raise ValueError(
                 f"use custom_mask, custom_mask length must larger than total kv length"
             )
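The functional change here: jax.numpy arrays, like NumPy arrays, expose `size` as an integer attribute rather than a method, so the pre-fix spelling `custom_mask.size()` raised `TypeError: 'int' object is not callable` whenever the custom-mask branch was taken. A minimal sketch of the distinction:

import jax.numpy as jnp

custom_mask = jnp.ones((2, 8), dtype=jnp.bool_)

print(custom_mask.size)  # 16 -- `size` is an int (total element count)

try:
    custom_mask.size()   # the pre-fix spelling
except TypeError as e:
    print(e)             # 'int' object is not callable

With the fix, `custom_mask.size` is the total number of mask elements, which is what gets compared against the total kv length `jnp.cumsum(kv_lens)[-1]`.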

python/sgl_jax/test/test_flashattention.py

Lines changed: 3 additions & 0 deletions
@@ -247,6 +247,7 @@ def align_to_size(l, size, value=0):
         attention_backend = FlashAttentionBackend(
             num_heads, num_kv_heads, head_dim, page_size=page_size, mesh=mesh
         )
+        print(f"!!!!!!!! {causal=}")
         if not causal:
             forward_mode = ForwardMode.TARGET_VERIFY
             custom_mask = create_custom_mask(lens)
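The added line in the test is a plain debug print; note that it relies on the f-string `=` specifier (Python 3.8+), which renders both the expression text and its value:

causal = False
print(f"!!!!!!!! {causal=}")  # prints: !!!!!!!! causal=False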
@@ -307,8 +308,10 @@ def align_to_size(l, size, value=0):
             cache_loc=cache_loc,
             extend_prefix_lens=extend_prefix_lens,
             extend_seq_lens=extend_seq_lens,
+            spec_info=spec_info,
         )
         fb.attn_backend.forward_metadata = attention_backend.get_forward_metadata(mwb)
+
         return fb, q, k, v
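The remaining additions thread `spec_info` into the batch object (`mwb`) that `attention_backend.get_forward_metadata(mwb)` consumes. The diff does not show the batch type's definition, but the failure mode of an omitted keyword field is generic: it silently falls back to its default. A hypothetical sketch, with illustrative names that are not the repo's API:

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class WorkerBatch:                   # hypothetical stand-in for mwb's type
    extend_seq_lens: Any
    spec_info: Optional[Any] = None  # silently None unless passed explicitly

without = WorkerBatch(extend_seq_lens=[4, 8])
with_spec = WorkerBatch(extend_seq_lens=[4, 8], spec_info="draft-token-info")
print(without.spec_info, with_spec.spec_info)  # None draft-token-info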