
Commit 18d914c

Rifur13 authored and Google-ML-Automation committed
[Pallas MGPU] Adding support for squeezed block dims in the pipeline BlockSpecs. They can be identified with a None or pl.Squeezed.
PiperOrigin-RevId: 839975575
1 parent ba024e3 commit 18d914c
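
Editor's note: as a quick orientation, here is a minimal usage sketch of the feature this commit adds, modeled directly on the test added below. The imports and the pl.pallas_call wrapper are assumptions for illustration, not part of the diff. A dim marked None or pl.Squeezed() in a pipeline BlockSpec is indexed by the grid and dropped from the SMEM block, so the kernel body sees a block without that axis.

# Sketch only: mirrors tests/pallas/mosaic_gpu_test.py::test_emit_with_squeezed_dim.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

shape = (16, 256)

def body(_, x_smem, o_smem):
  # The leading dim is squeezed away, so the SMEM blocks are 1-D: (256,).
  o_smem[...] = x_smem[...] + 1

def kernel(x_gmem, o_gmem):
  plgpu.emit_pipeline(
      body,
      # pl.Squeezed() (or None) marks a block dim that is indexed, not tiled.
      in_specs=[pl.BlockSpec((pl.Squeezed(), shape[1]), lambda i: (i, 0))],
      out_specs=[pl.BlockSpec((pl.Squeezed(), shape[1]), lambda i: (i, 0))],
      grid=(shape[0],),
      max_concurrent_steps=2,
  )(x_gmem, o_gmem)

kernel_fn = pl.pallas_call(
    kernel,
    in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
    out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
    out_shape=jax.ShapeDtypeStruct(shape, jnp.int32),
)
x = jnp.arange(16 * 256, dtype=jnp.int32).reshape(shape)
# On a GPU with Mosaic GPU support, kernel_fn(x) should equal x + 1.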

File tree

2 files changed: +103 -13 lines changed


jax/_src/pallas/mosaic_gpu/pipeline.py

Lines changed: 42 additions & 13 deletions
@@ -59,8 +59,15 @@ def _get_block_size(
     raise NotImplementedError(f"Unsupported block size type: {type(bd)}")

 def _get_block_shape(spec: pallas_core.BlockSpec):
-  assert spec.block_shape is not None
-  return tuple(_get_block_size(bd) for bd in spec.block_shape)
+  if spec.block_shape is None:
+    raise ValueError("Block shape must be specified.")
+
+  block_shape = tuple(
+      _get_block_size(bd)
+      for bd in spec.block_shape
+      if not (bd is None or isinstance(bd, pl.Squeezed))
+  )
+  return block_shape


 map_brefs = functools.partial(
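
Editor's note: the net effect of the _get_block_shape change is that squeezed dims contribute no axis to the SMEM window. A simplified illustration with plain int dims only; the helper name is made up for this note, the real logic is the generator expression above.

from jax.experimental import pallas as pl

def smem_block_shape(block_shape):
  # Squeezed dims (None or pl.Squeezed()) are dropped from the SMEM shape.
  return tuple(
      d for d in block_shape
      if not (d is None or isinstance(d, pl.Squeezed))
  )

assert smem_block_shape((pl.Squeezed(), 256)) == (256,)
assert smem_block_shape((None, 128, 64)) == (128, 64)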
@@ -84,18 +91,27 @@ def get_ref_for_slot(
       return self.gmem_ref
     return self.smem_ref.at[slot]

-  def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]:
+  def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice | jax.Array, ...]:
     index_map = self.spec.index_map
     assert index_map is not None
+    assert self.spec.block_shape is not None
     # We don't allow Python scalars here, because they are interpreted
     # differently depending on the x32/x64 mode.
     assert all(i.dtype == jnp.dtype(jnp.int32) for i in grid_indices)
-    sizes = _get_block_shape(self.spec)
+
+    def _make_block_slice(block_index: jax.Array, bd: pl.BlockDim | int | None):
+      match bd:
+        case int():
+          return pl.Slice(block_index * bd, bd)
+        case pl.Blocked(block_size):
+          return pl.Slice(block_index * block_size, block_size)
+        case None | pl.Squeezed():
+          return block_index
+        case _:
+          raise ValueError(f"Unsupported block dimension type: {bd}")
+
     return tuple(
-        pl.Slice(idx * size, size)  # type: ignore[arg-type]
-        for idx, size in zip(
-            index_map(*grid_indices), sizes  # type: ignore[arg-type]
-        )
+        map(_make_block_slice, index_map(*grid_indices), self.spec.block_shape)
     )

   def copy_in(self, slot, grid_indices, barrier_ref, barrier_slot=None):
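
Editor's note: with this change compute_gmem_slice no longer returns a pl.Slice for every dim. Blocked/int dims still yield a pl.Slice of the block size, while squeezed dims yield the bare scalar index, which selects and drops that GMEM axis. A tiny illustration of the two cases; the variable names and values are made up.

import jax.numpy as jnp
from jax.experimental import pallas as pl

i = jnp.asarray(3, dtype=jnp.int32)  # a block index produced by the index_map

blocked_entry = pl.Slice(i * 128, 128)  # int / pl.Blocked(128) dim: a 128-wide slice
squeezed_entry = i                      # None / pl.Squeezed() dim: just the index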
@@ -372,7 +388,8 @@ def loop_body(step, carry):
         continue
       assert last_store_slices[idx] is not None
       new_store_slices[idx] = tuple(
-          _Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices)
+          _Slice(s.start, s.size) if isinstance(s, pl.Slice) else s
+          for s in bref.compute_gmem_slice(indices)
       )
       are_same_slices = map(
           lambda old, new: old == new,
@@ -430,11 +447,16 @@ def do_fetch():
       fetch_indices = _inc_grid_by_1(fetch_indices, grid)
     fetch_index_levels.append(fetch_indices)

+  def _init_store_slice(bd):
+    if bd is None or isinstance(bd, pl.Squeezed):
+      return jnp.array(-1, dtype=jnp.int32)
+    return _Slice(-1, -1)
+
   # TODO(justinfu): Only store base pointer instead of all indices.
   last_store_slices = [
       None
       if bref.is_index_invariant
-      else (_Slice(-1, -1),) * len(bref.spec.block_shape)
+      else tuple(map(_init_store_slice, bref.spec.block_shape))
       for bref in out_brefs
   ]
   last_indices, _, _, final_carry = lax.fori_loop(
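
Editor's note: last_store_slices is what lets the pipeline skip stores whose GMEM window did not change between steps. For a squeezed dim the tracked value is now a scalar index, so its sentinel is jnp.array(-1, jnp.int32) rather than _Slice(-1, -1); no real step produces index -1, so the first comparison is guaranteed to fail and the first store always happens. Sketch of that comparison, with illustrative names.

import jax.numpy as jnp

last_idx = jnp.array(-1, dtype=jnp.int32)  # sentinel for a squeezed dim
new_idx = jnp.array(0, dtype=jnp.int32)    # index produced on the first step
unchanged = last_idx == new_idx            # False, so the store is not skipped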
@@ -690,7 +712,7 @@ def _get_scoped_allocs(*gmem_refs: AbstractRefPytree):
     slots = max_concurrent_steps if has_seq_dim else 1
     smem_allocs.append(
         gpu_core.SMEM(
-            (slots, *spec.block_shape),  # type: ignore
+            (slots, *_get_block_shape(spec)),  # type: ignore
             gmem_ref.dtype,
             transforms=getattr(spec, "transforms", ()),
         )
@@ -880,7 +902,8 @@ def compute_loop_body(step, carry):
         continue
       assert last_store_slices[idx] is not None
       new_store_slices[idx] = tuple(
-          _Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices)
+          _Slice(s.start, s.size) if isinstance(s, pl.Slice) else s
+          for s in bref.compute_gmem_slice(indices)
       )
       are_same_slices = map(
           lambda old, new: old == new,
@@ -895,11 +918,17 @@ def compute_loop_body(step, carry):
     next_indices = _inc_grid_by_1(indices, grid)
     return (next_indices, new_store_slices, next_body_carry)
   init_indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid)
+
+  def _init_store_slice(bd):
+    if bd is None or isinstance(bd, pl.Squeezed):
+      return jnp.array(-1, dtype=jnp.int32)
+    return _Slice(-1, -1)
+
   # TODO(justinfu): Only store base pointer instead of all indices.
   last_store_slices = [
       None
       if bref.is_index_invariant
-      else (_Slice(-1, -1),) * len(bref.spec.block_shape)
+      else tuple(map(_init_store_slice, bref.spec.block_shape))
       for bref in flat_out_brefs
   ]

tests/pallas/mosaic_gpu_test.py

Lines changed: 61 additions & 0 deletions
@@ -4972,6 +4972,35 @@ def kernel_body(_, o_smem, carry):
         kernel_fn(), jnp.tile(jnp.repeat(jnp.arange(num_steps), 64), (64, 1))
     )

+  @parameterized.parameters((pl.Squeezed(),), (None,))
+  def test_emit_with_squeezed_dim(self, squeezed_dim):
+
+    shape = (16, 256)
+    num_steps = shape[0]
+
+    def kernel(x_gmem, o_gmem):
+      plgpu.emit_pipeline(
+          kernel_body,
+          in_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))],
+          out_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))],
+          grid=(num_steps,),
+          max_concurrent_steps=2,
+      )(x_gmem, o_gmem)
+
+    def kernel_body(_, in_smem, o_smem):
+      assert in_smem.shape == (shape[1],)
+      assert o_smem.shape == (shape[1],)
+      o_smem[...] = in_smem[...] + 1
+
+    kernel_fn = self.pallas_call(
+        kernel,
+        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
+        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
+        out_shape=jax.ShapeDtypeStruct((16, 256), jnp.int32),
+    )
+    x = jnp.arange(16 * 256, dtype=jnp.int32).reshape(16, 256)
+    np.testing.assert_array_equal(kernel_fn(x), x + 1)
+

 class PipelineWGTest(
     PipelineTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup
@@ -5661,6 +5690,38 @@ def pipeline_body(_, x_smem, o_smem):
     )
     np.testing.assert_array_equal(y, np.stack([x + 1.0, x + 1.0]))

+  @parameterized.parameters((pl.Squeezed(),), (None,))
+  def test_emit_with_squeezed_dim(self, squeezed_dim):
+    self.skip_if_wg_semantics()
+
+    shape = (16, 256)
+    num_steps = shape[0]
+
+    def kernel(x_gmem, o_gmem):
+      plgpu.emit_pipeline_warp_specialized(
+          kernel_body,
+          in_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))],
+          out_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))],
+          grid=(num_steps,),
+          max_concurrent_steps=2,
+          num_compute_wgs=1,
+          memory_registers=40,
+          wg_axis="wg",
+      )(x_gmem, o_gmem)
+
+    def kernel_body(_, in_smem, o_smem):
+      o_smem[...] = in_smem[...] + 1
+
+    kernel_fn = self.kernel(
+        kernel,
+        out_shape=jax.ShapeDtypeStruct((16, 256), jnp.int32),
+        num_threads=2,
+        thread_name="wg",
+    )
+    x = jnp.arange(16 * 256, dtype=jnp.int32).reshape(16, 256)
+    np.testing.assert_array_equal(kernel_fn(x), x + 1)
+
+

 class WarpSpecializedPipelineWGTest(
     WarpSpecializedPipelineTest,
