@@ -59,8 +59,15 @@ def _get_block_size(
5959 raise NotImplementedError (f"Unsupported block size type: { type (bd )} " )
6060
def _get_block_shape(spec: pallas_core.BlockSpec):
  """Returns the concrete (integer) block shape described by ``spec``.

  Dimensions that are squeezed away — entries that are ``None`` or
  ``pl.Squeezed`` — are dropped entirely, so the returned tuple can be
  shorter than ``spec.block_shape``. NOTE(review): presumably this is so
  the result matches the rank of the SMEM buffer actually allocated for
  the block — confirm against the SMEM allocation site.

  Raises:
    ValueError: if ``spec.block_shape`` was never specified.
  """
  if spec.block_shape is None:
    raise ValueError("Block shape must be specified.")

  # Keep only real (blocked/int) dimensions; squeezed dims contribute no
  # extent to the materialized block.
  block_shape = tuple(
      _get_block_size(bd)
      for bd in spec.block_shape
      if not (bd is None or isinstance(bd, pl.Squeezed))
  )
  return block_shape
6471
6572
6673map_brefs = functools .partial (
@@ -84,18 +91,27 @@ def get_ref_for_slot(
8491 return self .gmem_ref
8592 return self .smem_ref .at [slot ]
8693
def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice | jax.Array, ...]:
  """Computes the GMEM indexer for this ref at the given grid indices.

  For each dimension of ``spec.block_shape`` the result holds either a
  ``pl.Slice`` (for blocked/int dimensions: start = block index * block
  size) or the raw ``jax.Array`` block index itself (for ``None`` /
  ``pl.Squeezed`` dimensions, which select a single element rather than
  a slice). The tuple therefore has one entry per block_shape dimension.
  """
  index_map = self.spec.index_map
  assert index_map is not None
  assert self.spec.block_shape is not None
  # We don't allow Python scalars here, because they are interpreted
  # differently depending on the x32/x64 mode.
  assert all(i.dtype == jnp.dtype(jnp.int32) for i in grid_indices)

  def _make_block_slice(block_index: jax.Array, bd: pl.BlockDim | int | None):
    # Maps one (block index, block-dimension spec) pair to its GMEM indexer.
    match bd:
      case int():
        return pl.Slice(block_index * bd, bd)
      case pl.Blocked(block_size):
        return pl.Slice(block_index * block_size, block_size)
      case None | pl.Squeezed():
        # Squeezed dims index a single element: return the index directly.
        return block_index
      case _:
        raise ValueError(f"Unsupported block dimension type: {bd}")

  # index_map yields one block index per dimension; pair each with its
  # corresponding block_shape entry.
  return tuple(
      map(_make_block_slice, index_map(*grid_indices), self.spec.block_shape)
  )
100116
101117 def copy_in (self , slot , grid_indices , barrier_ref , barrier_slot = None ):
@@ -372,7 +388,8 @@ def loop_body(step, carry):
372388 continue
373389 assert last_store_slices [idx ] is not None
374390 new_store_slices [idx ] = tuple (
375- _Slice (s .start , s .size ) for s in bref .compute_gmem_slice (indices )
391+ _Slice (s .start , s .size ) if isinstance (s , pl .Slice ) else s
392+ for s in bref .compute_gmem_slice (indices )
376393 )
377394 are_same_slices = map (
378395 lambda old , new : old == new ,
@@ -430,11 +447,16 @@ def do_fetch():
430447 fetch_indices = _inc_grid_by_1 (fetch_indices , grid )
431448 fetch_index_levels .append (fetch_indices )
432449
450+ def _init_store_slice (bd ):
451+ if bd is None or isinstance (bd , pl .Squeezed ):
452+ return jnp .array (- 1 , dtype = jnp .int32 )
453+ return _Slice (- 1 , - 1 )
454+
433455 # TODO(justinfu): Only store base pointer instead of all indices.
434456 last_store_slices = [
435457 None
436458 if bref .is_index_invariant
437- else ( _Slice ( - 1 , - 1 ),) * len ( bref .spec .block_shape )
459+ else tuple ( map ( _init_store_slice , bref .spec .block_shape ) )
438460 for bref in out_brefs
439461 ]
440462 last_indices , _ , _ , final_carry = lax .fori_loop (
@@ -690,7 +712,7 @@ def _get_scoped_allocs(*gmem_refs: AbstractRefPytree):
690712 slots = max_concurrent_steps if has_seq_dim else 1
691713 smem_allocs .append (
692714 gpu_core .SMEM (
693- (slots , * spec . block_shape ), # type: ignore
715+ (slots , * _get_block_shape ( spec ) ), # type: ignore
694716 gmem_ref .dtype ,
695717 transforms = getattr (spec , "transforms" , ()),
696718 )
@@ -880,7 +902,8 @@ def compute_loop_body(step, carry):
880902 continue
881903 assert last_store_slices [idx ] is not None
882904 new_store_slices [idx ] = tuple (
883- _Slice (s .start , s .size ) for s in bref .compute_gmem_slice (indices )
905+ _Slice (s .start , s .size ) if isinstance (s , pl .Slice ) else s
906+ for s in bref .compute_gmem_slice (indices )
884907 )
885908 are_same_slices = map (
886909 lambda old , new : old == new ,
@@ -895,11 +918,17 @@ def compute_loop_body(step, carry):
895918 next_indices = _inc_grid_by_1 (indices , grid )
896919 return (next_indices , new_store_slices , next_body_carry )
897920 init_indices = (jnp .asarray (0 , dtype = jnp .int32 ),) * len (grid )
921+
922+ def _init_store_slice (bd ):
923+ if bd is None or isinstance (bd , pl .Squeezed ):
924+ return jnp .array (- 1 , dtype = jnp .int32 )
925+ return _Slice (- 1 , - 1 )
926+
898927 # TODO(justinfu): Only store base pointer instead of all indices.
899928 last_store_slices = [
900929 None
901930 if bref .is_index_invariant
902- else ( _Slice ( - 1 , - 1 ),) * len ( bref .spec .block_shape )
931+ else tuple ( map ( _init_store_slice , bref .spec .block_shape ) )
903932 for bref in flat_out_brefs
904933 ]
905934
0 commit comments