Commit 9af7216

Rifur13 authored and Google-ML-Automation committed

[Pallas MGPU] Simplify how we keep track of the current output slices.

Keeping track of the full store slices is unnecessary: the slice size doesn't change, so we really only care about the start of the slices.
PiperOrigin-RevId: 846273629
1 parent b8cd917 commit 9af7216
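
The gist of the change, as a minimal runnable sketch (the names `index_map` and `store_predicate` here are illustrative stand-ins, not the pipeline's internals): every output block has a fixed shape, so the GMEM destination of a store is fully determined by the block start indices returned by the spec's `index_map`. The loop can therefore carry just those starts and predicate each SMEM->GMEM copy on whether they changed.

```python
import functools

import jax.numpy as jnp
from jax import lax

# Hypothetical stand-in for BlockSpec.index_map: grid step -> block starts.
index_map = lambda i: (i // 4, i % 4)

def store_predicate(last_start, new_start):
  # True iff any start coordinate moved since the previous step; the
  # elementwise comparison stays traceable inside lax.fori_loop.
  same = map(lambda old, new: old == new, last_start, new_start)
  return ~functools.reduce(lax.bitwise_and, same)

last_start = (jnp.array(-1), jnp.array(-1))  # sentinel: first step stores
for i in range(3):
  new_start = tuple(jnp.asarray(x) for x in index_map(i))
  print(i, bool(store_predicate(last_start, new_start)))  # True each step
  last_start = new_start
```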

1 file changed: +23, -59 lines


jax/_src/pallas/mosaic_gpu/pipeline.py

```diff
@@ -182,26 +182,6 @@ def _inc_grid_by_1(
 def _in_smem(spec: pallas_core.BlockSpec) -> bool:
   return spec.memory_space in (None, gpu_core.SMEM)
 
-
-# ``pl.Slice`` uses a different pytree encoding, depending on whether the
-# start/size are static or dynamic. This leads to pytree structure mismatch
-# in the pipeline body. So, we define a different ``Slice`` class below.
-
-
-@dataclasses.dataclass(frozen=True)
-class _Slice:
-  start: int | jax.Array
-  size: int | jax.Array
-
-  def __eq__(self, other: _Slice) -> jax.Array:  # type: ignore
-    return lax.bitwise_and(self.start == other.start, self.size == other.size)
-
-
-jax.tree_util.register_dataclass(
-    _Slice, data_fields=["start", "size"], meta_fields=[]
-)
-
-
 def _downcast_spec(
     spec: gpu_core.BlockSpec | pallas_core.BlockSpec,
 ) -> gpu_core.BlockSpec:
```
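
For context on the removed workaround: a `lax.fori_loop` carry must keep an identical pytree structure on every iteration, and, per the removed comment, `pl.Slice` flattens differently depending on whether its start/size are static or traced, hence the custom `_Slice`. Plain tuples of arrays, which the code now carries instead, always flatten the same way. A small sketch of that property (the claim about `pl.Slice` is taken from the comment above, not demonstrated here):

```python
import jax
import jax.numpy as jnp

# A tuple contributes one pytree leaf per element whether the elements are
# Python ints or traced jax.Arrays, so its structure is stable across
# fori_loop iterations.
static_start = (0, 0)
traced_start = (jnp.int32(0), jnp.int32(1))
assert (jax.tree_util.tree_structure(static_start)
        == jax.tree_util.tree_structure(traced_start))
```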
```diff
@@ -357,7 +337,7 @@ def prologue(step, fetch_indices):
   # need to fetch more data anyway.
   def loop_body(step, carry):
     slot = lax.rem(step, max_concurrent_steps)
-    indices, fetch_index_levels, last_store_slices, prev_body_carry = carry
+    indices, fetch_index_levels, last_store_indices, prev_body_carry = carry
 
     if barrier_ref is not None:
       # Wait for the current GMEM->SMEM copy to complete, if any.
```
```diff
@@ -381,20 +361,17 @@
       gpu_primitives.commit_smem()
 
     # Copy the output from SMEM to GMEM.
-    new_store_slices = last_store_slices[:]
+    new_store_indices = last_store_indices[:]
     for idx, bref in enumerate(out_brefs):
       if bref.is_index_invariant:
-        assert last_store_slices[idx] is None
+        assert last_store_indices[idx] is None
         continue
-      assert last_store_slices[idx] is not None
-      new_store_slices[idx] = tuple(
-          _Slice(s.start, s.size) if isinstance(s, pl.Slice) else s
-          for s in bref.compute_gmem_slice(indices)
-      )
+      assert last_store_indices[idx] is not None
+      new_store_indices[idx] = bref.spec.index_map(*indices)
       are_same_slices = map(
           lambda old, new: old == new,
-          last_store_slices[idx],
-          new_store_slices[idx],
+          last_store_indices[idx],
+          new_store_indices[idx],
       )
       slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices)
       is_last_step = step == num_steps - 1
```
```diff
@@ -436,7 +413,7 @@ def do_fetch():
     return (
         _inc_grid_by_1(indices, grid),
         next_fetch_indices_levels,
-        new_store_slices,
+        new_store_indices,
         next_body_carry if init_carry is not None else None,
     )
 
```
```diff
@@ -447,23 +424,18 @@ def do_fetch():
     fetch_indices = _inc_grid_by_1(fetch_indices, grid)
     fetch_index_levels.append(fetch_indices)
 
-  def _init_store_slice(bd):
-    if bd is None or isinstance(bd, pl.Squeezed):
-      return jnp.array(-1, dtype=jnp.int32)
-    return _Slice(-1, -1)
-
   # TODO(justinfu): Only store base pointer instead of all indices.
-  last_store_slices = [
+  last_store_indices = [
       None
       if bref.is_index_invariant
-      else tuple(map(_init_store_slice, bref.spec.block_shape))
+      else (jnp.array(-1),) * len(bref.spec.block_shape)
      for bref in out_brefs
   ]
   last_indices, _, _, final_carry = lax.fori_loop(
       0,
       num_steps,
       loop_body,
-      (indices, fetch_index_levels, last_store_slices, init_carry),
+      (indices, fetch_index_levels, last_store_indices, init_carry),
   )
 
   # Outputs invariant to the sequential axis are never written from inside the
```
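
A note on the `-1` initialization (my inference, not stated in the commit): `index_map` returns non-negative block indices, so a carry seeded with `-1` can never compare equal on the first step, which forces the first iteration to perform its store.

```python
import jax.numpy as jnp

block_rank = 2  # stand-in for len(bref.spec.block_shape)
last_store_indices = (jnp.array(-1),) * block_rank
first = (jnp.array(0), jnp.array(0))  # any real index_map output
# No coordinate matches the sentinel, so the "indices changed" predicate
# is True and the first store is not skipped.
assert not any(bool(o == n) for o, n in zip(last_store_indices, first))
```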
```diff
@@ -848,7 +820,7 @@ def compute_block():
   needs_epilogue = any(bref.is_index_invariant for bref in smem_out_brefs)
 
   def compute_loop_body(step, carry):
-    indices, last_store_slices, prev_body_carry = carry
+    indices, last_store_indices, prev_body_carry = carry
     slot = lax.rem(step, max_concurrent_steps)
     consumed_slot = lax.rem(step - delay_release, max_concurrent_steps)
     # Wait for the current GMEM->SMEM copies to complete.
```
```diff
@@ -895,40 +867,32 @@ def compute_loop_body(step, carry):
     if copies_out_in_loop:
       gpu_primitives.commit_smem()
 
-    new_store_slices = last_store_slices[:]
+    new_store_indices = last_store_indices[:]
     for idx, bref in enumerate(flat_out_brefs):
       if bref.is_index_invariant:
-        assert last_store_slices[idx] is None
+        assert last_store_indices[idx] is None
         continue
-      assert last_store_slices[idx] is not None
-      new_store_slices[idx] = tuple(
-          _Slice(s.start, s.size) if isinstance(s, pl.Slice) else s
-          for s in bref.compute_gmem_slice(indices)
-      )
+      assert last_store_indices[idx] is not None
+      new_store_indices[idx] = bref.spec.index_map(*indices)
       are_same_slices = map(
           lambda old, new: old == new,
-          last_store_slices[idx],
-          new_store_slices[idx],
+          last_store_indices[idx],
+          new_store_indices[idx],
       )
       slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices)
       bref.copy_out(_get_slot(slot, not bref.is_index_invariant),
                     indices,
                     predicate=slices_changed)
     gpu_primitives.commit_smem_to_gmem_group()
     next_indices = _inc_grid_by_1(indices, grid)
-    return (next_indices, new_store_slices, next_body_carry)
+    return (next_indices, new_store_indices, next_body_carry)
   init_indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid)
 
-  def _init_store_slice(bd):
-    if bd is None or isinstance(bd, pl.Squeezed):
-      return jnp.array(-1, dtype=jnp.int32)
-    return _Slice(-1, -1)
-
   # TODO(justinfu): Only store base pointer instead of all indices.
-  last_store_slices = [
+  last_store_indices = [
       None
       if bref.is_index_invariant
-      else tuple(map(_init_store_slice, bref.spec.block_shape))
+      else (jnp.array(-1),) * len(bref.spec.block_shape)
       for bref in flat_out_brefs
   ]
 
```
```diff
@@ -939,7 +903,7 @@ def pipeline_callback(user_init_carry):
     if last_indices is not None:
       raise ValueError(
           "Cannot call pipeline more than once in `compute_context`")
-    init_loop_carry = (init_indices, last_store_slices, user_init_carry)
+    init_loop_carry = (init_indices, last_store_indices, user_init_carry)
     last_indices, _, final_body_carry = lax.fori_loop(0,
                                                       num_steps,
                                                       compute_loop_body,
```
```diff
@@ -952,7 +916,7 @@ def pipeline_callback(user_init_carry):
   assert compute_context is None
   last_indices, _, _ = lax.fori_loop(
       0, num_steps, compute_loop_body,
-      (init_indices, last_store_slices, None)
+      (init_indices, last_store_indices, None)
   )
 
   # Handle index_invariant outputs after the loop. They are not
```
