@@ -59,8 +59,15 @@ def _get_block_size(
5959 raise NotImplementedError (f"Unsupported block size type: { type (bd )} " )
6060
def _get_block_shape(spec: pallas_core.BlockSpec):
  """Returns the sizes of the non-squeezed dimensions of ``spec.block_shape``.

  Entries of ``spec.block_shape`` that are ``None`` or instances of
  ``pl.Squeezed`` are dropped. Every remaining entry is converted to a size
  with ``_get_block_size``, preserving the original order.

  Args:
    spec: The block spec whose ``block_shape`` is inspected.

  Returns:
    A tuple with one size per retained block dimension.

  Raises:
    ValueError: If ``spec.block_shape`` is ``None``.
  """
  dims = spec.block_shape
  if dims is None:
    raise ValueError("Block shape must be specified.")

  sizes = []
  for dim in dims:
    # Squeezed dimensions (spelled either ``None`` or ``pl.Squeezed``) do not
    # contribute an entry to the resulting shape.
    if dim is None or isinstance(dim, pl.Squeezed):
      continue
    sizes.append(_get_block_size(dim))
  return tuple(sizes)
6471
6572
6673map_brefs = functools .partial (
@@ -84,18 +91,27 @@ def get_ref_for_slot(
8491 return self .gmem_ref
8592 return self .smem_ref .at [slot ]
8693
87- def compute_gmem_slice (self , grid_indices ) -> tuple [pl .Slice , ...]:
94+ def compute_gmem_slice (self , grid_indices ) -> tuple [pl .Slice | jax . Array , ...]:
8895 index_map = self .spec .index_map
8996 assert index_map is not None
97+ assert self .spec .block_shape is not None
9098 # We don't allow Python scalars here, because they are interpreted
9199 # differently depending on the x32/x64 mode.
92100 assert all (i .dtype == jnp .dtype (jnp .int32 ) for i in grid_indices )
93- sizes = _get_block_shape (self .spec )
101+
102+ def _make_block_slice (block_index : jax .Array , bd : pl .BlockDim | int | None ):
103+ match bd :
104+ case int ():
105+ return pl .Slice (block_index * bd , bd )
106+ case pl .Blocked (block_size ):
107+ return pl .Slice (block_index * block_size , block_size )
108+ case None | pl .Squeezed ():
109+ return block_index
110+ case _:
111+ raise ValueError (f"Unsupported block dimension type: { bd } " )
112+
94113 return tuple (
95- pl .Slice (idx * size , size ) # type: ignore[arg-type]
96- for idx , size in zip (
97- index_map (* grid_indices ), sizes # type: ignore[arg-type]
98- )
114+ map (_make_block_slice , index_map (* grid_indices ), self .spec .block_shape )
99115 )
100116
101117 def copy_in (self , slot , grid_indices , barrier_ref , barrier_slot = None ):
@@ -166,26 +182,6 @@ def _inc_grid_by_1(
def _in_smem(spec: pallas_core.BlockSpec) -> bool:
  """Returns whether ``spec.memory_space`` is ``None`` or ``gpu_core.SMEM``."""
  smem_spaces = (None, gpu_core.SMEM)
  return spec.memory_space in smem_spaces
168184
169-
170- # ``pl.Slice`` uses a different pytree encoding, depending on whether the
171- # start/size are static or dynamic. This leads to pytree structure mismatch
172- # in the pipeline body. So, we define a different ``Slice`` class below.
173-
174-
175- @dataclasses .dataclass (frozen = True )
176- class _Slice :
177- start : int | jax .Array
178- size : int | jax .Array
179-
180- def __eq__ (self , other : _Slice ) -> jax .Array : # type: ignore
181- return lax .bitwise_and (self .start == other .start , self .size == other .size )
182-
183-
184- jax .tree_util .register_dataclass (
185- _Slice , data_fields = ["start" , "size" ], meta_fields = []
186- )
187-
188-
189185def _downcast_spec (
190186 spec : gpu_core .BlockSpec | pallas_core .BlockSpec ,
191187) -> gpu_core .BlockSpec :
@@ -341,7 +337,7 @@ def prologue(step, fetch_indices):
341337 # need to fetch more data anyway.
342338 def loop_body (step , carry ):
343339 slot = lax .rem (step , max_concurrent_steps )
344- indices , fetch_index_levels , last_store_slices , prev_body_carry = carry
340+ indices , fetch_index_levels , last_store_indices , prev_body_carry = carry
345341
346342 if barrier_ref is not None :
347343 # Wait for the current GMEM->SMEM copy to complete, if any.
@@ -365,19 +361,17 @@ def loop_body(step, carry):
365361 gpu_primitives .commit_smem ()
366362
367363 # Copy the output from SMEM to GMEM.
368- new_store_slices = last_store_slices [:]
364+ new_store_indices = last_store_indices [:]
369365 for idx , bref in enumerate (out_brefs ):
370366 if bref .is_index_invariant :
371- assert last_store_slices [idx ] is None
367+ assert last_store_indices [idx ] is None
372368 continue
373- assert last_store_slices [idx ] is not None
374- new_store_slices [idx ] = tuple (
375- _Slice (s .start , s .size ) for s in bref .compute_gmem_slice (indices )
376- )
369+ assert last_store_indices [idx ] is not None
370+ new_store_indices [idx ] = bref .spec .index_map (* indices )
377371 are_same_slices = map (
378372 lambda old , new : old == new ,
379- last_store_slices [idx ],
380- new_store_slices [idx ],
373+ last_store_indices [idx ],
374+ new_store_indices [idx ],
381375 )
382376 slices_changed = ~ functools .reduce (lax .bitwise_and , are_same_slices )
383377 is_last_step = step == num_steps - 1
@@ -419,7 +413,7 @@ def do_fetch():
419413 return (
420414 _inc_grid_by_1 (indices , grid ),
421415 next_fetch_indices_levels ,
422- new_store_slices ,
416+ new_store_indices ,
423417 next_body_carry if init_carry is not None else None ,
424418 )
425419
@@ -431,17 +425,17 @@ def do_fetch():
431425 fetch_index_levels .append (fetch_indices )
432426
433427 # TODO(justinfu): Only store base pointer instead of all indices.
434- last_store_slices = [
428+ last_store_indices = [
435429 None
436430 if bref .is_index_invariant
437- else (_Slice ( - 1 , - 1 ),) * len (bref .spec .block_shape )
431+ else (jnp . array ( - 1 ),) * len (bref .spec .block_shape )
438432 for bref in out_brefs
439433 ]
440434 last_indices , _ , _ , final_carry = lax .fori_loop (
441435 0 ,
442436 num_steps ,
443437 loop_body ,
444- (indices , fetch_index_levels , last_store_slices , init_carry ),
438+ (indices , fetch_index_levels , last_store_indices , init_carry ),
445439 )
446440
447441 # Outputs invariant to the sequential axis are never written from inside the
@@ -690,7 +684,7 @@ def _get_scoped_allocs(*gmem_refs: AbstractRefPytree):
690684 slots = max_concurrent_steps if has_seq_dim else 1
691685 smem_allocs .append (
692686 gpu_core .SMEM (
693- (slots , * spec . block_shape ), # type: ignore
687+ (slots , * _get_block_shape ( spec ) ), # type: ignore
694688 gmem_ref .dtype ,
695689 transforms = getattr (spec , "transforms" , ()),
696690 )
@@ -826,7 +820,7 @@ def compute_block():
826820 needs_epilogue = any (bref .is_index_invariant for bref in smem_out_brefs )
827821
828822 def compute_loop_body (step , carry ):
829- indices , last_store_slices , prev_body_carry = carry
823+ indices , last_store_indices , prev_body_carry = carry
830824 slot = lax .rem (step , max_concurrent_steps )
831825 consumed_slot = lax .rem (step - delay_release , max_concurrent_steps )
832826 # Wait for the current GMEM->SMEM copies to complete.
@@ -873,33 +867,32 @@ def compute_loop_body(step, carry):
873867 if copies_out_in_loop :
874868 gpu_primitives .commit_smem ()
875869
876- new_store_slices = last_store_slices [:]
870+ new_store_indices = last_store_indices [:]
877871 for idx , bref in enumerate (flat_out_brefs ):
878872 if bref .is_index_invariant :
879- assert last_store_slices [idx ] is None
873+ assert last_store_indices [idx ] is None
880874 continue
881- assert last_store_slices [idx ] is not None
882- new_store_slices [idx ] = tuple (
883- _Slice (s .start , s .size ) for s in bref .compute_gmem_slice (indices )
884- )
875+ assert last_store_indices [idx ] is not None
876+ new_store_indices [idx ] = bref .spec .index_map (* indices )
885877 are_same_slices = map (
886878 lambda old , new : old == new ,
887- last_store_slices [idx ],
888- new_store_slices [idx ],
879+ last_store_indices [idx ],
880+ new_store_indices [idx ],
889881 )
890882 slices_changed = ~ functools .reduce (lax .bitwise_and , are_same_slices )
891883 bref .copy_out (_get_slot (slot , not bref .is_index_invariant ),
892884 indices ,
893885 predicate = slices_changed )
894886 gpu_primitives .commit_smem_to_gmem_group ()
895887 next_indices = _inc_grid_by_1 (indices , grid )
896- return (next_indices , new_store_slices , next_body_carry )
888+ return (next_indices , new_store_indices , next_body_carry )
897889 init_indices = (jnp .asarray (0 , dtype = jnp .int32 ),) * len (grid )
890+
898891 # TODO(justinfu): Only store base pointer instead of all indices.
899- last_store_slices = [
892+ last_store_indices = [
900893 None
901894 if bref .is_index_invariant
902- else (_Slice ( - 1 , - 1 ),) * len (bref .spec .block_shape )
895+ else (jnp . array ( - 1 ),) * len (bref .spec .block_shape )
903896 for bref in flat_out_brefs
904897 ]
905898
@@ -910,7 +903,7 @@ def pipeline_callback(user_init_carry):
910903 if last_indices is not None :
911904 raise ValueError (
912905 "Cannot call pipeline more than once in `compute_context`" )
913- init_loop_carry = (init_indices , last_store_slices , user_init_carry )
906+ init_loop_carry = (init_indices , last_store_indices , user_init_carry )
914907 last_indices , _ , final_body_carry = lax .fori_loop (0 ,
915908 num_steps ,
916909 compute_loop_body ,
@@ -923,7 +916,7 @@ def pipeline_callback(user_init_carry):
923916 assert compute_context is None
924917 last_indices , _ , _ = lax .fori_loop (
925918 0 , num_steps , compute_loop_body ,
926- (init_indices , last_store_slices , None )
919+ (init_indices , last_store_indices , None )
927920 )
928921
929922 # Handle index_invariant outputs after the loop. They are not
0 commit comments