Commit c23b497

Rifur13 authored and Google-ML-Automation committed
[Pallas MGPU] Simplify how we keep track of the current output slices.

Keeping track of the full store slices is unnecessary because the slice size doesn't change; we really only care about the start of each slice.
PiperOrigin-RevId: 845422049
1 parent ba024e3 commit c23b497
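
The gist of the change, as a minimal runnable sketch (the `index_map` lambda, the `store_step` helper, and the carried tuples here are simplified stand-ins for the pipeline internals, not the actual API): instead of carrying `_Slice(start, size)` objects for every output dimension, the loop now carries only the raw index-map outputs and predicates each SMEM->GMEM store on any of them changing.

import functools

import jax.numpy as jnp
from jax import lax

index_map = lambda i: (i, 0)  # block indices; the block sizes are static

def store_step(last_store_indices, grid_index):
  # Compare only the index-map outputs: since slice sizes never change,
  # equal block indices imply equal GMEM store slices.
  new_store_indices = index_map(grid_index)
  are_same = map(lambda old, new: old == new,
                 last_store_indices, new_store_indices)
  slices_changed = ~functools.reduce(lax.bitwise_and, are_same)
  return slices_changed, new_store_indices

# -1 sentinels (as in the diff below) guarantee a store on the first step.
last_store_indices = (jnp.array(-1), jnp.array(-1))
changed, last_store_indices = store_step(last_store_indices,
                                         jnp.asarray(0, dtype=jnp.int32))
print(changed)  # True: the first step always writes its output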

2 files changed: +110 −56 lines

jax/_src/pallas/mosaic_gpu/pipeline.py

Lines changed: 49 additions & 56 deletions
@@ -59,8 +59,15 @@ def _get_block_size(
   raise NotImplementedError(f"Unsupported block size type: {type(bd)}")

 def _get_block_shape(spec: pallas_core.BlockSpec):
-  assert spec.block_shape is not None
-  return tuple(_get_block_size(bd) for bd in spec.block_shape)
+  if spec.block_shape is None:
+    raise ValueError("Block shape must be specified.")
+
+  block_shape = tuple(
+      _get_block_size(bd)
+      for bd in spec.block_shape
+      if not (bd is None or isinstance(bd, pl.Squeezed))
+  )
+  return block_shape


 map_brefs = functools.partial(
@@ -84,18 +91,27 @@ def get_ref_for_slot(
       return self.gmem_ref
     return self.smem_ref.at[slot]

-  def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]:
+  def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice | jax.Array, ...]:
     index_map = self.spec.index_map
     assert index_map is not None
+    assert self.spec.block_shape is not None
     # We don't allow Python scalars here, because they are interpreted
     # differently depending on the x32/x64 mode.
     assert all(i.dtype == jnp.dtype(jnp.int32) for i in grid_indices)
-    sizes = _get_block_shape(self.spec)
+
+    def _make_block_slice(block_index: jax.Array, bd: pl.BlockDim | int | None):
+      match bd:
+        case int():
+          return pl.Slice(block_index * bd, bd)
+        case pl.Blocked(block_size):
+          return pl.Slice(block_index * block_size, block_size)
+        case None | pl.Squeezed():
+          return block_index
+        case _:
+          raise ValueError(f"Unsupported block dimension type: {bd}")
+
     return tuple(
-        pl.Slice(idx * size, size)  # type: ignore[arg-type]
-        for idx, size in zip(
-            index_map(*grid_indices), sizes  # type: ignore[arg-type]
-        )
+        map(_make_block_slice, index_map(*grid_indices), self.spec.block_shape)
     )

   def copy_in(self, slot, grid_indices, barrier_ref, barrier_slot=None):
@@ -166,26 +182,6 @@ def _inc_grid_by_1(
 def _in_smem(spec: pallas_core.BlockSpec) -> bool:
   return spec.memory_space in (None, gpu_core.SMEM)

-
-# ``pl.Slice`` uses a different pytree encoding, depending on whether the
-# start/size are static or dynamic. This leads to pytree structure mismatch
-# in the pipeline body. So, we define a different ``Slice`` class below.
-
-
-@dataclasses.dataclass(frozen=True)
-class _Slice:
-  start: int | jax.Array
-  size: int | jax.Array
-
-  def __eq__(self, other: _Slice) -> jax.Array:  # type: ignore
-    return lax.bitwise_and(self.start == other.start, self.size == other.size)
-
-
-jax.tree_util.register_dataclass(
-    _Slice, data_fields=["start", "size"], meta_fields=[]
-)
-
-
 def _downcast_spec(
     spec: gpu_core.BlockSpec | pallas_core.BlockSpec,
 ) -> gpu_core.BlockSpec:
@@ -341,7 +337,7 @@ def prologue(step, fetch_indices):
     # need to fetch more data anyway.
   def loop_body(step, carry):
     slot = lax.rem(step, max_concurrent_steps)
-    indices, fetch_index_levels, last_store_slices, prev_body_carry = carry
+    indices, fetch_index_levels, last_store_indices, prev_body_carry = carry

     if barrier_ref is not None:
       # Wait for the current GMEM->SMEM copy to complete, if any.
@@ -365,19 +361,17 @@ def loop_body(step, carry):
       gpu_primitives.commit_smem()

     # Copy the output from SMEM to GMEM.
-    new_store_slices = last_store_slices[:]
+    new_store_indices = last_store_indices[:]
     for idx, bref in enumerate(out_brefs):
       if bref.is_index_invariant:
-        assert last_store_slices[idx] is None
+        assert last_store_indices[idx] is None
         continue
-      assert last_store_slices[idx] is not None
-      new_store_slices[idx] = tuple(
-          _Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices)
-      )
+      assert last_store_indices[idx] is not None
+      new_store_indices[idx] = bref.spec.index_map(*indices)
       are_same_slices = map(
           lambda old, new: old == new,
-          last_store_slices[idx],
-          new_store_slices[idx],
+          last_store_indices[idx],
+          new_store_indices[idx],
       )
       slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices)
       is_last_step = step == num_steps - 1
@@ -419,7 +413,7 @@ def do_fetch():
     return (
         _inc_grid_by_1(indices, grid),
         next_fetch_indices_levels,
-        new_store_slices,
+        new_store_indices,
         next_body_carry if init_carry is not None else None,
     )

@@ -431,17 +425,17 @@ def do_fetch():
     fetch_index_levels.append(fetch_indices)

   # TODO(justinfu): Only store base pointer instead of all indices.
-  last_store_slices = [
+  last_store_indices = [
       None
       if bref.is_index_invariant
-      else (_Slice(-1, -1),) * len(bref.spec.block_shape)
+      else (jnp.array(-1),) * len(bref.spec.block_shape)
       for bref in out_brefs
   ]
   last_indices, _, _, final_carry = lax.fori_loop(
       0,
       num_steps,
       loop_body,
-      (indices, fetch_index_levels, last_store_slices, init_carry),
+      (indices, fetch_index_levels, last_store_indices, init_carry),
   )

   # Outputs invariant to the sequential axis are never written from inside the
@@ -690,7 +684,7 @@ def _get_scoped_allocs(*gmem_refs: AbstractRefPytree):
     slots = max_concurrent_steps if has_seq_dim else 1
     smem_allocs.append(
         gpu_core.SMEM(
-            (slots, *spec.block_shape),  # type: ignore
+            (slots, *_get_block_shape(spec)),  # type: ignore
             gmem_ref.dtype,
             transforms=getattr(spec, "transforms", ()),
         )
@@ -826,7 +820,7 @@ def compute_block():
   needs_epilogue = any(bref.is_index_invariant for bref in smem_out_brefs)

   def compute_loop_body(step, carry):
-    indices, last_store_slices, prev_body_carry = carry
+    indices, last_store_indices, prev_body_carry = carry
     slot = lax.rem(step, max_concurrent_steps)
     consumed_slot = lax.rem(step - delay_release, max_concurrent_steps)
     # Wait for the current GMEM->SMEM copies to complete.
@@ -873,33 +867,32 @@ def compute_loop_body(step, carry):
     if copies_out_in_loop:
       gpu_primitives.commit_smem()

-    new_store_slices = last_store_slices[:]
+    new_store_indices = last_store_indices[:]
     for idx, bref in enumerate(flat_out_brefs):
       if bref.is_index_invariant:
-        assert last_store_slices[idx] is None
+        assert last_store_indices[idx] is None
         continue
-      assert last_store_slices[idx] is not None
-      new_store_slices[idx] = tuple(
-          _Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices)
-      )
+      assert last_store_indices[idx] is not None
+      new_store_indices[idx] = bref.spec.index_map(*indices)
       are_same_slices = map(
           lambda old, new: old == new,
-          last_store_slices[idx],
-          new_store_slices[idx],
+          last_store_indices[idx],
+          new_store_indices[idx],
       )
       slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices)
       bref.copy_out(_get_slot(slot, not bref.is_index_invariant),
                     indices,
                     predicate=slices_changed)
     gpu_primitives.commit_smem_to_gmem_group()
     next_indices = _inc_grid_by_1(indices, grid)
-    return (next_indices, new_store_slices, next_body_carry)
+    return (next_indices, new_store_indices, next_body_carry)
   init_indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid)
+
   # TODO(justinfu): Only store base pointer instead of all indices.
-  last_store_slices = [
+  last_store_indices = [
       None
       if bref.is_index_invariant
-      else (_Slice(-1, -1),) * len(bref.spec.block_shape)
+      else (jnp.array(-1),) * len(bref.spec.block_shape)
       for bref in flat_out_brefs
   ]

@@ -910,7 +903,7 @@ def pipeline_callback(user_init_carry):
     if last_indices is not None:
       raise ValueError(
           "Cannot call pipeline more than once in `compute_context`")
-    init_loop_carry = (init_indices, last_store_slices, user_init_carry)
+    init_loop_carry = (init_indices, last_store_indices, user_init_carry)
     last_indices, _, final_body_carry = lax.fori_loop(0,
                                                       num_steps,
                                                       compute_loop_body,
@@ -923,7 +916,7 @@ def pipeline_callback(user_init_carry):
     assert compute_context is None
     last_indices, _, _ = lax.fori_loop(
         0, num_steps, compute_loop_body,
-        (init_indices, last_store_slices, None)
+        (init_indices, last_store_indices, None)
     )

     # Handle index_invariant outputs after the loop. They are not
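
For context on the `_get_block_shape` and SMEM-allocation changes above, here is a standalone sketch of the new filtering behavior; the `Squeezed` class is a hypothetical stand-in for `pl.Squeezed`, and plain ints stand in for blocked dimensions (`_get_block_size` handles `pl.Blocked` and friends in the real code):

import dataclasses

@dataclasses.dataclass(frozen=True)
class Squeezed:  # hypothetical stand-in for pl.Squeezed
  pass

def get_block_shape(block_shape):
  # Squeezed/None dims are indexed with a scalar (see _make_block_slice
  # above), so they contribute no axis to the SMEM block.
  return tuple(
      bd for bd in block_shape
      if not (bd is None or isinstance(bd, Squeezed))
  )

# The pipeline then allocates (slots, 256) in SMEM rather than (slots, 1, 256):
assert get_block_shape((Squeezed(), 256)) == (256,)
assert get_block_shape((None, 256)) == (256,)
assert get_block_shape((16, 256)) == (16, 256)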

tests/pallas/mosaic_gpu_test.py

Lines changed: 61 additions & 0 deletions
@@ -4972,6 +4972,35 @@ def kernel_body(_, o_smem, carry):
         kernel_fn(), jnp.tile(jnp.repeat(jnp.arange(num_steps), 64), (64, 1))
     )

+  @parameterized.parameters((pl.Squeezed(),), (None,))
+  def test_emit_with_squeezed_dim(self, squeezed_dim):
+
+    shape = (16, 256)
+    num_steps = shape[0]
+
+    def kernel(x_gmem, o_gmem):
+      plgpu.emit_pipeline(
+          kernel_body,
+          in_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))],
+          out_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))],
+          grid=(num_steps,),
+          max_concurrent_steps=2,
+      )(x_gmem, o_gmem)
+
+    def kernel_body(_, in_smem, o_smem):
+      assert in_smem.shape == (shape[1],)
+      assert o_smem.shape == (shape[1],)
+      o_smem[...] = in_smem[...] + 1
+
+    kernel_fn = self.pallas_call(
+        kernel,
+        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
+        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
+        out_shape=jax.ShapeDtypeStruct((16, 256), jnp.int32),
+    )
+    x = jnp.arange(16 * 256, dtype=jnp.int32).reshape(16, 256)
+    np.testing.assert_array_equal(kernel_fn(x), x + 1)
+

 class PipelineWGTest(
     PipelineTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup
@@ -5661,6 +5690,38 @@ def pipeline_body(_, x_smem, o_smem):
     )
     np.testing.assert_array_equal(y, np.stack([x + 1.0, x + 1.0]))

+  @parameterized.parameters((pl.Squeezed(),), (None,))
+  def test_emit_with_squeezed_dim(self, squeezed_dim):
+    self.skip_if_wg_semantics()
+
+    shape = (16, 256)
+    num_steps = shape[0]
+
+    def kernel(x_gmem, o_gmem):
+      plgpu.emit_pipeline_warp_specialized(
+          kernel_body,
+          in_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))],
+          out_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))],
+          grid=(num_steps,),
+          max_concurrent_steps=2,
+          num_compute_wgs=1,
+          memory_registers=40,
+          wg_axis="wg",
+      )(x_gmem, o_gmem)
+
+    def kernel_body(_, in_smem, o_smem):
+      o_smem[...] = in_smem[...] + 1
+
+    kernel_fn = self.kernel(
+        kernel,
+        out_shape=jax.ShapeDtypeStruct((16, 256), jnp.int32),
+        num_threads=2,
+        thread_name="wg",
+    )
+    x = jnp.arange(16 * 256, dtype=jnp.int32).reshape(16, 256)
+    np.testing.assert_array_equal(kernel_fn(x), x + 1)
+
+

 class WarpSpecializedPipelineWGTest(
     WarpSpecializedPipelineTest,
