@@ -182,26 +182,6 @@ def _inc_grid_by_1(
182182def _in_smem (spec : pallas_core .BlockSpec ) -> bool :
183183 return spec .memory_space in (None , gpu_core .SMEM )
184184
185-
186- # ``pl.Slice`` uses a different pytree encoding, depending on whether the
187- # start/size are static or dynamic. This leads to pytree structure mismatch
188- # in the pipeline body. So, we define a different ``Slice`` class below.
189-
190-
191- @dataclasses .dataclass (frozen = True )
192- class _Slice :
193- start : int | jax .Array
194- size : int | jax .Array
195-
196- def __eq__ (self , other : _Slice ) -> jax .Array : # type: ignore
197- return lax .bitwise_and (self .start == other .start , self .size == other .size )
198-
199-
200- jax .tree_util .register_dataclass (
201- _Slice , data_fields = ["start" , "size" ], meta_fields = []
202- )
203-
204-
205185def _downcast_spec (
206186 spec : gpu_core .BlockSpec | pallas_core .BlockSpec ,
207187) -> gpu_core .BlockSpec :
@@ -357,7 +337,7 @@ def prologue(step, fetch_indices):
357337 # need to fetch more data anyway.
358338 def loop_body (step , carry ):
359339 slot = lax .rem (step , max_concurrent_steps )
360- indices , fetch_index_levels , last_store_slices , prev_body_carry = carry
340+ indices , fetch_index_levels , last_store_indices , prev_body_carry = carry
361341
362342 if barrier_ref is not None :
363343 # Wait for the current GMEM->SMEM copy to complete, if any.
@@ -381,20 +361,17 @@ def loop_body(step, carry):
381361 gpu_primitives .commit_smem ()
382362
383363 # Copy the output from SMEM to GMEM.
384- new_store_slices = last_store_slices [:]
364+ new_store_indices = last_store_indices [:]
385365 for idx , bref in enumerate (out_brefs ):
386366 if bref .is_index_invariant :
387- assert last_store_slices [idx ] is None
367+ assert last_store_indices [idx ] is None
388368 continue
389- assert last_store_slices [idx ] is not None
390- new_store_slices [idx ] = tuple (
391- _Slice (s .start , s .size ) if isinstance (s , pl .Slice ) else s
392- for s in bref .compute_gmem_slice (indices )
393- )
369+ assert last_store_indices [idx ] is not None
370+ new_store_indices [idx ] = bref .spec .index_map (* indices )
394371 are_same_slices = map (
395372 lambda old , new : old == new ,
396- last_store_slices [idx ],
397- new_store_slices [idx ],
373+ last_store_indices [idx ],
374+ new_store_indices [idx ],
398375 )
399376 slices_changed = ~ functools .reduce (lax .bitwise_and , are_same_slices )
400377 is_last_step = step == num_steps - 1
@@ -436,7 +413,7 @@ def do_fetch():
436413 return (
437414 _inc_grid_by_1 (indices , grid ),
438415 next_fetch_indices_levels ,
439- new_store_slices ,
416+ new_store_indices ,
440417 next_body_carry if init_carry is not None else None ,
441418 )
442419
@@ -447,23 +424,18 @@ def do_fetch():
447424 fetch_indices = _inc_grid_by_1 (fetch_indices , grid )
448425 fetch_index_levels .append (fetch_indices )
449426
450- def _init_store_slice (bd ):
451- if bd is None or isinstance (bd , pl .Squeezed ):
452- return jnp .array (- 1 , dtype = jnp .int32 )
453- return _Slice (- 1 , - 1 )
454-
455427 # TODO(justinfu): Only store base pointer instead of all indices.
456- last_store_slices = [
428+ last_store_indices = [
457429 None
458430 if bref .is_index_invariant
459- else tuple ( map ( _init_store_slice , bref .spec .block_shape ) )
431+ else ( jnp . array ( - 1 ),) * len ( bref .spec .block_shape )
460432 for bref in out_brefs
461433 ]
462434 last_indices , _ , _ , final_carry = lax .fori_loop (
463435 0 ,
464436 num_steps ,
465437 loop_body ,
466- (indices , fetch_index_levels , last_store_slices , init_carry ),
438+ (indices , fetch_index_levels , last_store_indices , init_carry ),
467439 )
468440
469441 # Outputs invariant to the sequential axis are never written from inside the
@@ -848,7 +820,7 @@ def compute_block():
848820 needs_epilogue = any (bref .is_index_invariant for bref in smem_out_brefs )
849821
850822 def compute_loop_body (step , carry ):
851- indices , last_store_slices , prev_body_carry = carry
823+ indices , last_store_indices , prev_body_carry = carry
852824 slot = lax .rem (step , max_concurrent_steps )
853825 consumed_slot = lax .rem (step - delay_release , max_concurrent_steps )
854826 # Wait for the current GMEM->SMEM copies to complete.
@@ -895,40 +867,32 @@ def compute_loop_body(step, carry):
895867 if copies_out_in_loop :
896868 gpu_primitives .commit_smem ()
897869
898- new_store_slices = last_store_slices [:]
870+ new_store_indices = last_store_indices [:]
899871 for idx , bref in enumerate (flat_out_brefs ):
900872 if bref .is_index_invariant :
901- assert last_store_slices [idx ] is None
873+ assert last_store_indices [idx ] is None
902874 continue
903- assert last_store_slices [idx ] is not None
904- new_store_slices [idx ] = tuple (
905- _Slice (s .start , s .size ) if isinstance (s , pl .Slice ) else s
906- for s in bref .compute_gmem_slice (indices )
907- )
875+ assert last_store_indices [idx ] is not None
876+ new_store_indices [idx ] = bref .spec .index_map (* indices )
908877 are_same_slices = map (
909878 lambda old , new : old == new ,
910- last_store_slices [idx ],
911- new_store_slices [idx ],
879+ last_store_indices [idx ],
880+ new_store_indices [idx ],
912881 )
913882 slices_changed = ~ functools .reduce (lax .bitwise_and , are_same_slices )
914883 bref .copy_out (_get_slot (slot , not bref .is_index_invariant ),
915884 indices ,
916885 predicate = slices_changed )
917886 gpu_primitives .commit_smem_to_gmem_group ()
918887 next_indices = _inc_grid_by_1 (indices , grid )
919- return (next_indices , new_store_slices , next_body_carry )
888+ return (next_indices , new_store_indices , next_body_carry )
920889 init_indices = (jnp .asarray (0 , dtype = jnp .int32 ),) * len (grid )
921890
922- def _init_store_slice (bd ):
923- if bd is None or isinstance (bd , pl .Squeezed ):
924- return jnp .array (- 1 , dtype = jnp .int32 )
925- return _Slice (- 1 , - 1 )
926-
927891 # TODO(justinfu): Only store base pointer instead of all indices.
928- last_store_slices = [
892+ last_store_indices = [
929893 None
930894 if bref .is_index_invariant
931- else tuple ( map ( _init_store_slice , bref .spec .block_shape ) )
895+ else ( jnp . array ( - 1 ),) * len ( bref .spec .block_shape )
932896 for bref in flat_out_brefs
933897 ]
934898
@@ -939,7 +903,7 @@ def pipeline_callback(user_init_carry):
939903 if last_indices is not None :
940904 raise ValueError (
941905 "Cannot call pipeline more than once in `compute_context`" )
942- init_loop_carry = (init_indices , last_store_slices , user_init_carry )
906+ init_loop_carry = (init_indices , last_store_indices , user_init_carry )
943907 last_indices , _ , final_body_carry = lax .fori_loop (0 ,
944908 num_steps ,
945909 compute_loop_body ,
@@ -952,7 +916,7 @@ def pipeline_callback(user_init_carry):
952916 assert compute_context is None
953917 last_indices , _ , _ = lax .fori_loop (
954918 0 , num_steps , compute_loop_body ,
955- (init_indices , last_store_slices , None )
919+ (init_indices , last_store_indices , None )
956920 )
957921
958922 # Handle index_invariant outputs after the loop. They are not
0 commit comments