@@ -254,16 +254,15 @@ def _ragged_paged_attention_kernel(
     sem_ids_ref,  # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
     bo_ids_ref,  # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
     bkv_update_ids_ref,  # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
+    custom_mask_ref,  # (flatten_total_kv_len,)
     # Input
     q_hbm_ref,  # [actual_num_kv_heads, padded_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
     kv_hbm_ref,  # [padded_num_tokens, num_kv_heads_x2 // kv_packing, kv_packing, head_dim] - Fused KV with interleaved [K1,V1,K2,V2,...]
     kv_cache_fused_hbm_ref,  # [total_num_pages, page_size, num_kv_heads_interleaved // kv_packing, kv_packing, head_dim]
-    custom_mask_ref,  # (flatten_total_kv_len,), int8, dma not support bool type
     # Output
     o_hbm_ref,  # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
     updated_kv_cache_fused_hbm_ref,  # [total_num_pages, page_size, num_kv_heads_interleaved // kv_packing, kv_packing, head_dim]
     # Scratch
-    bkvmask_ref,  # [2, bq_sz, bkv_sz]
     bkv_fused_x2_ref,  # [2, bkv_sz, num_kv_heads_interleaved // kv_packing, kv_packing, head_dim]
     bq_x2_ref,  # [2, actual_num_kv_heads, bq_sz, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
     bo_x2_ref,  # [2, actual_num_kv_heads, bq_sz, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
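With this hunk the flattened custom mask becomes a scalar-prefetch operand instead of an HBM input that was DMA'd through a dedicated scratch buffer. A minimal host-side sketch of how such a flattened int8 mask could be packed; the helper name and the mask polarity (nonzero marks a position to exclude, mirroring the causal branch later in the kernel) are assumptions for illustration, not part of this change:

import jax.numpy as jnp

def pack_custom_mask(per_seq_masks):
    # Hypothetical helper: per_seq_masks[i] has shape (q_len_i, kv_len_i).
    # Assumed polarity: nonzero marks a position to mask out.
    flat = [m.astype(jnp.int8).reshape(-1) for m in per_seq_masks]
    return jnp.concatenate(flat)  # shape: (sum(q_len_i * kv_len_i),)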
@@ -324,54 +323,19 @@ def _ragged_paged_attention_kernel(
     q_len = q_end - q_start
     kv_len = kv_lens_ref[seq_idx]
 
+    cur_seq_mask_start = cu_seq_mask_lens[seq_idx]
+    cur_seq_mask_len = q_len * kv_len
+    cur_seq_mask = custom_mask_ref[
+        cur_seq_mask_start : cur_seq_mask_start + cur_seq_mask_len
+    ].reshape(q_len, kv_len)
+
     def _async_copy(src, dst, sem, wait):
         cp = pltpu.make_async_copy(src, dst, sem)
         if wait:
             cp.wait()
         else:
             cp.start()
 
-    def _fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx, *, wait=False):
-        sem = sems.at[4, bkvmask_sem_idx]
-        assert sem.dtype == sems.dtype, f"######## {sem.dtype=} {sems.dtype=}"
-        kvmask_fused_vmem_ref = bkvmask_ref.at[bkvmask_sem_idx]
-
-        kv_len = kv_lens_ref[seq_idx]
-        mask_len = kv_len
-        mask_start = bkvmask_idx * bkv_sz
-        mask_left = mask_len - mask_start
-        load_kv_sz = jnp.minimum(bkv_sz, mask_left)
-
-        q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
-        q_end = cu_q_lens_ref[seq_idx + 1]
-        load_q_sz = jnp.minimum(bq_sz, q_end - q_len_start)
-
-        cur_seq_mask_start = cu_seq_mask_lens[seq_idx]
-        cur_bq_mask_start = cur_seq_mask_start + bq_idx * bq_sz * kv_len
-
-        # Whether using custom mask, depends on causal args
-        # flatten mask: [TTTTTTFFFFTFTTFFFTTFFTTTTTFFFFTTTTTTFT,FFFTFFTFTTTTTFTFFFFFTTFTTTTFTFTTFTTT]
-        #                ^kv_start    ^mask_start
-        #                             <--load_sz-->
-
-        def loop_body(i, _):
-            start = cur_bq_mask_start + i * kv_len + mask_start
-            start = jnp.minimum(custom_mask_ref.shape[0], start)
-            _async_copy(
-                custom_mask_ref.at[pl.ds(start, load_kv_sz)],
-                kvmask_fused_vmem_ref.at[i, pl.ds(0, load_kv_sz)],
-                sem,
-                wait,
-            )
-
-        lax.fori_loop(
-            0,
-            load_q_sz,
-            loop_body,
-            None,
-            unroll=False,
-        )
-
     def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
         sem = sems.at[0, bkv_sem_idx]
         kv_fused_vmem_ref = bkv_fused_x2_ref.at[bkv_sem_idx]
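The added lines above recover each sequence's 2D mask by slicing the flattened array at cu_seq_mask_lens[seq_idx] and reshaping to (q_len, kv_len). A small standalone illustration, assuming cu_seq_mask_lens is the exclusive cumulative sum of q_len_i * kv_len_i (that derivation is not visible in this hunk):

import jax.numpy as jnp

q_lens = jnp.array([3, 2], jnp.int32)    # illustrative per-sequence q lengths
kv_lens = jnp.array([5, 4], jnp.int32)   # illustrative per-sequence kv lengths
mask_lens = q_lens * kv_lens             # [15, 8]
cu_seq_mask_lens = jnp.concatenate(
    [jnp.zeros((1,), jnp.int32), jnp.cumsum(mask_lens)]
)                                        # [0, 15, 23]

flat_mask = jnp.arange(23, dtype=jnp.int8)   # stand-in for the packed mask
seq_idx = 1
start = int(cu_seq_mask_lens[seq_idx])
cur_seq_mask = flat_mask[start : start + int(mask_lens[seq_idx])].reshape(2, 4)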
@@ -505,12 +469,6 @@ def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
             wait,
         )
 
-    def start_fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx):
-        return _fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx)
-
-    def wait_fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx):
-        return _fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx, wait=True)
-
     def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
         return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
 
@@ -691,12 +649,6 @@ def prefetch_next_bkv():
         sem_ids_ref[1] = next_bkv_sem_idx
         start_fetch_bkv(next_seq_idx, next_bkv_idx, next_bkv_sem_idx)
 
-        @pl.when(causal == 0)
-        def _():
-            start_fetch_mask(
-                next_seq_idx, bq_idx, next_bkv_idx, next_bkv_sem_idx
-            )
-
     # Wait for cur bq if not ready yet
     @pl.when(bkv_idx == 0)
     def wait_cur_bq():
@@ -705,11 +657,6 @@ def wait_cur_bq():
     # Wait for cur bkv
     offset, update_sz = wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
 
-    # Wait for kv mask if not use causal mask
-    @pl.when(causal == 0)
-    def _():
-        wait_fetch_mask(seq_idx, bq_idx, bkv_idx, bkv_sem_idx)
-
     # Start updating bkv to kv cache if applicable.
     # Only needed in first bq loop.
     @pl.when(jnp.logical_and(update_sz > 0, bq_idx == 0))
@@ -746,10 +693,23 @@ def batch_prepare_queries():
         return jnp.stack(q_heads, axis=0)
 
     def load_mask():
-        mask = bkvmask_ref[bkv_sem_idx, :actual_bq_sz]
-        # assert False, f'{mask.shape=} {jnp.zeros((actual_num_kv_heads, actual_bq_sz*num_q_heads_per_kv_head, mask.shape[-1])).shape=}'
+        bq_mask_start = bq_idx * bq_sz
+        bq_mask_end = (bq_idx + 1) * bq_sz
+        bq_mask_offset = lax.select(
+            bq_mask_end - q_len < 0, bq_sz, bq_mask_end - q_len
+        )
+
+        bkv_mask_start = bkv_idx * bkv_sz
+        bkv_mask_end = (bkv_idx + 1) * bkv_sz
+        bkv_mask_offset = lax.select(
+            bkv_mask_end - kv_len < 0, bkv_sz, bkv_mask_end - kv_len
+        )
+
+        cur_bq_bkv_mask = cur_seq_mask[
+            bq_mask_start:bq_mask_offset, bkv_mask_start:bkv_mask_offset
+        ]
         num_q_heads_per_kv_head_mask = jnp.concat(
-            [mask] * num_q_heads_per_kv_head
+            [cur_bq_bkv_mask] * num_q_heads_per_kv_head
         )
         num_kv_heads_mask = jnp.concat(
             [
@@ -759,12 +719,12 @@ def load_mask():
             ]
             * actual_num_kv_heads
         )
-        return num_kv_heads_mask
+        # convert custom mask from int8 to bool
+        return num_kv_heads_mask > 0
 
     # Load batched data
     k_batch, v_batch = batch_load_all_heads_kv()
     q_batch = batch_prepare_queries()
-    custom_mask = load_mask()
 
     def flash_attention(q_batch, k_batch, v_batch):
         q_batch_f32 = q_batch.astype(jnp.float32)
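For orientation, load_mask tiles one (bq, bkv) mask block across the query heads of each kv head and then across kv heads, ending up aligned with scores of shape (actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, bkv_sz), the same target shape the removed debug comment spelled out. A shape-only sketch with made-up sizes; the [None] expansion stands in for the lines not shown in this hunk:

import jax.numpy as jnp

bq_sz, bkv_sz, q_per_kv, n_kv = 8, 128, 4, 2
tile = jnp.zeros((bq_sz, bkv_sz), jnp.int8)        # one (bq, bkv) mask block
per_kv = jnp.concatenate([tile] * q_per_kv)        # (bq_sz * q_per_kv, bkv_sz)
all_kv = jnp.concatenate([per_kv[None]] * n_kv)    # (n_kv, bq_sz * q_per_kv, bkv_sz)
assert all_kv.shape == (n_kv, bq_sz * q_per_kv, bkv_sz)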
@@ -799,12 +759,8 @@ def flash_attention(q_batch, k_batch, v_batch):
         k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(
             jnp.int32, s.shape, 2
         )
-        # convert custom_mask from int8 to bool
-        mask = lax.select(
-            causal == 0,
-            custom_mask.astype(jnp.bool),
-            q_span < k_span,
-        )
+        mask = lax.cond(causal == 1, lambda: q_span < k_span, load_mask)
+
         if sliding_window is not None:
             mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)
 
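Replacing lax.select with lax.cond changes more than style: select would force both operands, the causal iota comparison and the gathered custom mask, to be materialized on every invocation, while cond only executes the branch selected at run time, so load_mask is skipped entirely when causal == 1. A toy illustration of that behavior (not kernel code):

import jax
import jax.numpy as jnp
from jax import lax

def causal_mask():
    i = lax.broadcasted_iota(jnp.int32, (4, 4), 0)
    j = lax.broadcasted_iota(jnp.int32, (4, 4), 1)
    return i < j  # True marks positions to mask out

def load_custom_mask():
    jax.debug.print("loading custom mask")  # fires only if this branch runs
    return jnp.zeros((4, 4), dtype=bool)

@jax.jit
def pick(causal):
    # Only the selected branch executes at run time.
    return lax.cond(causal == 1, causal_mask, load_custom_mask)

pick(jnp.int32(1))  # nothing is printed: the custom-mask branch is not executed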
@@ -1206,7 +1162,7 @@ def ragged_paged_attention(
     cu_q_lens: jax.Array,  # i32[padded_batch_size + 1]
     cu_kv_lens: jax.Array,  # i32[padded_batch_size + 1]
     distribution: jax.Array,  # i32[3]
-    custom_mask: jax.Array,  # if causal is True, custom_mask shape is [patten_total_kv_len], else [0]
+    custom_mask: jax.Array,  # used when causal is False (0); then its shape is [flatten_total_kv_len], otherwise it may be None
     *,
     causal: int = 1,  # 1: True, 0: False
     sm_scale: float = 1.0,
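A hedged caller-side sketch of the mask contract this signature implies: for causal == 0 the flattened mask should carry q_len_i * kv_len_i entries per sequence, concatenated in sequence order. The helper and its num_seqs argument are illustrative only, not part of the API:

import jax.numpy as jnp

def check_custom_mask(custom_mask, cu_q_lens, cu_kv_lens, num_seqs, causal):
    # Hypothetical pre-flight check for the non-causal path.
    if causal == 1:
        return  # the kernel builds its own causal mask
    q_lens = cu_q_lens[1 : num_seqs + 1] - cu_q_lens[:num_seqs]
    kv_lens = cu_kv_lens[1 : num_seqs + 1] - cu_kv_lens[:num_seqs]
    expected = int(jnp.sum(q_lens * kv_lens))
    assert custom_mask is not None and custom_mask.shape[0] == expected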
@@ -1343,17 +1299,14 @@ def ragged_paged_attention(
         # fix bug: XLA layout ({0}) does not match Mosaic layout ({0:T(128)}) for an operand of shape s32[0]
         custom_mask = jnp.empty((1, 128), dtype=jnp.int8)
     else:
-        assert (
-            custom_mask.dtype != jnp.bool
-        ), f"custom_mask bool dtype is not supported, use int32 instead. 0: False, 1: True"
+        custom_mask = custom_mask.astype(jnp.int8)
 
     grid = (distribution[2],)
 
     in_specs = [
         pl.BlockSpec(memory_space=pltpu.ANY),  # q
         pl.BlockSpec(memory_space=pltpu.ANY),  # kv_fused
         pl.BlockSpec(memory_space=pltpu.ANY),  # kv_cache_fused
-        pl.BlockSpec(memory_space=pltpu.ANY),  # custom_mask
     ]
 
     out_specs = [
@@ -1366,11 +1319,6 @@ def ragged_paged_attention(
         kv_cache_fused_processed.dtype,
     )
 
-    bkvmask_double_buf = pltpu.VMEM(
-        (2, bq_sz, bkv_sz),
-        jnp.bool,
-    )
-
     bq_double_buf = pltpu.VMEM(
         (2, actual_num_kv_heads, bq_sz, *q.shape[2:]),
         q.dtype,
@@ -1390,7 +1338,6 @@ def ragged_paged_attention(
     )
 
     scratch_shapes = [
-        bkvmask_double_buf,  # Double buffering for fused kv mask block with head interleaving.
         bkv_fused_double_buf,  # Double buffering for fused kv block with head interleaving.
         bq_double_buf,  # Double buffering for q block.
         bo_double_buf,  # Double buffering for output block.
@@ -1415,6 +1362,7 @@ def ragged_paged_attention(
         jnp.full((4,), -1, jnp.int32),
         # (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
         jnp.full((6,), -1, jnp.int32),
+        custom_mask,
     )
     scope_name = f"RPA-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
     kernel = jax.named_scope(scope_name)(
@@ -1454,8 +1402,8 @@ def ragged_paged_attention(
             ),
         ],
         input_output_aliases={
-            9: 0,  # q input -> q output
-            11: 1,  # kv_cache_fused input -> updated kv_cache_fused output
+            10: 0,  # q input -> q output
+            12: 1,  # kv_cache_fused input -> updated kv_cache_fused output
         },
         name=scope_name,
     )
@@ -1466,7 +1414,6 @@ def ragged_paged_attention(
         q,
         kv,
         kv_cache_fused_processed,
-        custom_mask,
     )
     return (
         prepare_outputs(output, actual_num_q_heads_per_kv_head, actual_head_dim),