@@ -59,8 +59,15 @@ def _get_block_size(
5959 raise NotImplementedError (f"Unsupported block size type: { type (bd )} " )
6060
def _get_block_shape(spec: pallas_core.BlockSpec):
  """Returns the sizes of the non-squeezed dimensions of ``spec.block_shape``.

  Dimensions given as ``None`` or ``pl.Squeezed`` are dropped; every other
  entry is converted with ``_get_block_size``.

  Raises:
    ValueError: if ``spec.block_shape`` is ``None``.
  """
  dims = spec.block_shape
  if dims is None:
    raise ValueError("Block shape must be specified.")

  sizes = []
  for dim in dims:
    # Squeezed dimensions do not contribute an axis to the block.
    if dim is None or isinstance(dim, pl.Squeezed):
      continue
    sizes.append(_get_block_size(dim))
  return tuple(sizes)
6471
6572
6673map_brefs = functools .partial (
@@ -84,18 +91,27 @@ def get_ref_for_slot(
8491 return self .gmem_ref
8592 return self .smem_ref .at [slot ]
8693
87- def compute_gmem_slice (self , grid_indices ) -> tuple [pl .Slice , ...]:
94+ def compute_gmem_slice (self , grid_indices ) -> tuple [pl .Slice | jax . Array , ...]:
8895 index_map = self .spec .index_map
8996 assert index_map is not None
9097 # We don't allow Python scalars here, because they are interpreted
9198 # differently depending on the x32/x64 mode.
9299 assert all (i .dtype == jnp .dtype (jnp .int32 ) for i in grid_indices )
93- sizes = _get_block_shape (self .spec )
100+
101+ def _make_block_slice (block_index : jax .Array , bd : pl .BlockDim ):
102+ match bd :
103+ case int ():
104+ return pl .Slice (block_index * bd , bd )
105+ case pl .Blocked (block_size ):
106+ return pl .Slice (block_index * block_size , block_size )
107+ case None | pl .Squeezed ():
108+ return block_index
109+ case _:
110+ raise ValueError (f"Unsupported block dimension type: { bd } " )
111+
94112 return tuple (
95- pl .Slice (idx * size , size ) # type: ignore[arg-type]
96- for idx , size in zip (
97- index_map (* grid_indices ), sizes # type: ignore[arg-type]
98- )
113+ _make_block_slice (bd , idx )
114+ for bd , idx in zip (index_map (* grid_indices ), self .spec .block_shape )
99115 )
100116
101117 def copy_in (self , slot , grid_indices , barrier_ref , barrier_slot = None ):
@@ -177,6 +193,10 @@ class _Slice:
177193 start : int | jax .Array
178194 size : int | jax .Array
179195
@classmethod
def from_val(cls, s: pl.Slice | jax.Array):
  """Builds an instance from either a ``pl.Slice`` or a bare index.

  A ``pl.Slice`` contributes its ``start`` and ``size``; any other value is
  used as the start with a size of 1.
  """
  if not isinstance(s, pl.Slice):
    return cls(s, 1)
  return cls(s.start, s.size)
199+
def __eq__(self, other: _Slice) -> jax.Array:  # type: ignore
  """Returns the bitwise AND of the ``start`` and ``size`` comparisons."""
  same_start = self.start == other.start
  same_size = self.size == other.size
  return lax.bitwise_and(same_start, same_size)
182202
@@ -372,7 +392,7 @@ def loop_body(step, carry):
372392 continue
373393 assert last_store_slices [idx ] is not None
374394 new_store_slices [idx ] = tuple (
375- _Slice ( s . start , s . size ) for s in bref .compute_gmem_slice (indices )
395+ _Slice . from_val ( s ) for s in bref .compute_gmem_slice (indices )
376396 )
377397 are_same_slices = map (
378398 lambda old , new : old == new ,
@@ -690,7 +710,7 @@ def _get_scoped_allocs(*gmem_refs: AbstractRefPytree):
690710 slots = max_concurrent_steps if has_seq_dim else 1
691711 smem_allocs .append (
692712 gpu_core .SMEM (
693- (slots , * spec . block_shape ), # type: ignore
713+ (slots , * _get_block_shape ( spec ) ), # type: ignore
694714 gmem_ref .dtype ,
695715 transforms = getattr (spec , "transforms" , ()),
696716 )
@@ -880,7 +900,7 @@ def compute_loop_body(step, carry):
880900 continue
881901 assert last_store_slices [idx ] is not None
882902 new_store_slices [idx ] = tuple (
883- _Slice ( s . start , s . size ) for s in bref .compute_gmem_slice (indices )
903+ _Slice . from_val ( s ) for s in bref .compute_gmem_slice (indices )
884904 )
885905 are_same_slices = map (
886906 lambda old , new : old == new ,
0 commit comments