
Commit faef876

Rifur13 authored and Google-ML-Automation committed

[Pallas MGPU] Add support for squeezed block dims in the pipeline BlockSpecs. They can be specified with None or pl.Squeezed.

PiperOrigin-RevId: 839975575

1 parent: 10a43df
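For orientation before the diff, here is a minimal usage sketch condensed from the tests added in this commit (the names and shapes are illustrative, and the surrounding pallas_call harness is omitted). A block dimension written as None or pl.Squeezed() is squeezed away, so each pipeline step sees a (256,) SMEM window instead of a (1, 256) one:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

shape = (16, 256)

def body(_, in_smem, o_smem):
  # The leading dimension is squeezed away, so the SMEM refs are 1D: (256,).
  o_smem[...] = in_smem[...] + 1

def kernel(x_gmem, o_gmem):
  plgpu.emit_pipeline(
      body,
      # None (or pl.Squeezed()) marks the leading block dim as squeezed.
      in_specs=[pl.BlockSpec((None, shape[1]), lambda i: (i, 0))],
      out_specs=[pl.BlockSpec((None, shape[1]), lambda i: (i, 0))],
      grid=(shape[0],),
      max_concurrent_steps=2,
  )(x_gmem, o_gmem)

The index_map still returns an index for the squeezed dimension; it simply selects which row of GMEM each pipeline step operates on.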

2 files changed: +90 −11 lines

jax/_src/pallas/mosaic_gpu/pipeline.py

Lines changed: 31 additions & 11 deletions
@@ -59,8 +59,15 @@ def _get_block_size(
       raise NotImplementedError(f"Unsupported block size type: {type(bd)}")

 def _get_block_shape(spec: pallas_core.BlockSpec):
-  assert spec.block_shape is not None
-  return tuple(_get_block_size(bd) for bd in spec.block_shape)
+  if spec.block_shape is None:
+    raise ValueError("Block shape must be specified.")
+
+  block_shape = tuple(
+      _get_block_size(bd)
+      for bd in spec.block_shape
+      if not (bd is None or isinstance(bd, pl.Squeezed))
+  )
+  return block_shape


 map_brefs = functools.partial(
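To make the effect of the new _get_block_shape concrete, here is a small standalone sketch (a mirror of the logic above written for illustration only, not the library helper, and covering only plain int block sizes): squeezed dims are filtered out, so they never reach the SMEM allocation.

from jax.experimental import pallas as pl

def block_shape_without_squeezed(block_shape):
  # Mirror of the filtering above: dims given as None or pl.Squeezed() are
  # dropped; plain int sizes are kept as-is.
  return tuple(
      bd for bd in block_shape
      if not (bd is None or isinstance(bd, pl.Squeezed))
  )

# (None, 256) -> (256,), so the SMEM buffer allocated later in this diff is
# (slots, 256) rather than (slots, 1, 256).
assert block_shape_without_squeezed((None, 256)) == (256,)
assert block_shape_without_squeezed((pl.Squeezed(), 256)) == (256,)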
@@ -84,18 +91,27 @@ def get_ref_for_slot(
       return self.gmem_ref
     return self.smem_ref.at[slot]

-  def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]:
+  def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice | jax.Array, ...]:
     index_map = self.spec.index_map
     assert index_map is not None
     # We don't allow Python scalars here, because they are interpreted
     # differently depending on the x32/x64 mode.
     assert all(i.dtype == jnp.dtype(jnp.int32) for i in grid_indices)
-    sizes = _get_block_shape(self.spec)
+
+    def _make_block_slice(block_index: jax.Array, bd: pl.BlockDim):
+      match bd:
+        case int():
+          return pl.Slice(block_index * bd, bd)
+        case pl.Blocked(block_size):
+          return pl.Slice(block_index * block_size, block_size)
+        case None | pl.Squeezed():
+          return block_index
+        case _:
+          raise ValueError(f"Unsupported block dimension type: {bd}")
+
     return tuple(
-        pl.Slice(idx * size, size)  # type: ignore[arg-type]
-        for idx, size in zip(
-            index_map(*grid_indices), sizes  # type: ignore[arg-type]
-        )
+        _make_block_slice(bd, idx)
+        for bd, idx in zip(index_map(*grid_indices), self.spec.block_shape)
     )

   def copy_in(self, slot, grid_indices, barrier_ref, barrier_slot=None):
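With squeezed dims, compute_gmem_slice no longer returns only pl.Slice objects: a squeezed entry is the bare block index. A hedged worked example, assuming block_shape=(None, 256) and index_map=lambda i: (i, 0) as in the tests added below (example_step_entries is a hypothetical name used only for this sketch):

from jax.experimental import pallas as pl

# At pipeline step i, index_map yields block indices (i, 0), so the resulting
# GMEM slice entries would be:
#   dim 0 (squeezed): the scalar index i, not a pl.Slice
#   dim 1 (blocked, size 256): pl.Slice(0 * 256, 256)
# i.e. the copy reads x[i, 0:256] and the SMEM window has shape (256,).
def example_step_entries(i):
  return (i, pl.Slice(0 * 256, 256))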
@@ -177,6 +193,10 @@ class _Slice:
   start: int | jax.Array
   size: int | jax.Array

+  @classmethod
+  def from_val(cls, s: pl.Slice | jax.Array):
+    return cls(s.start, s.size) if isinstance(s, pl.Slice) else cls(s, 1)
+
   def __eq__(self, other: _Slice) -> jax.Array:  # type: ignore
     return lax.bitwise_and(self.start == other.start, self.size == other.size)

@@ -372,7 +392,7 @@ def loop_body(step, carry):
         continue
       assert last_store_slices[idx] is not None
       new_store_slices[idx] = tuple(
-          _Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices)
+          _Slice.from_val(s) for s in bref.compute_gmem_slice(indices)
       )
       are_same_slices = map(
           lambda old, new: old == new,
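Because a squeezed entry is a bare index rather than a pl.Slice, the store-slice comparison above now normalizes both forms through _Slice.from_val. A standalone sketch of that normalization (an illustrative mirror, not the private helper): a bare index is treated as a size-1 slice starting at that index.

from jax.experimental import pallas as pl

def slice_start_and_size(s):
  # Mirror of _Slice.from_val above: keep (start, size) for a pl.Slice,
  # treat a bare index from a squeezed dim as a size-1 slice.
  return (s.start, s.size) if isinstance(s, pl.Slice) else (s, 1)

assert slice_start_and_size(pl.Slice(32, 16)) == (32, 16)
assert slice_start_and_size(7) == (7, 1)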
@@ -690,7 +710,7 @@ def _get_scoped_allocs(*gmem_refs: AbstractRefPytree):
     slots = max_concurrent_steps if has_seq_dim else 1
     smem_allocs.append(
         gpu_core.SMEM(
-            (slots, *spec.block_shape),  # type: ignore
+            (slots, *_get_block_shape(spec)),  # type: ignore
            gmem_ref.dtype,
            transforms=getattr(spec, "transforms", ()),
        )
@@ -880,7 +900,7 @@ def compute_loop_body(step, carry):
         continue
       assert last_store_slices[idx] is not None
       new_store_slices[idx] = tuple(
-          _Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices)
+          _Slice.from_val(s) for s in bref.compute_gmem_slice(indices)
       )
       are_same_slices = map(
           lambda old, new: old == new,

tests/pallas/mosaic_gpu_test.py

Lines changed: 59 additions & 0 deletions
@@ -4734,6 +4734,33 @@ def kernel_body(_, o_smem, carry):
         kernel_fn(), jnp.tile(jnp.repeat(jnp.arange(num_steps), 64), (64, 1))
     )

+  @parameterized.parameters((pl.Squeezed(),), (None,))
+  def test_emit_with_squeezed_dim(self, squeezed_dim):
+
+    shape = (16, 256)
+    num_steps = shape[0]
+
+    def kernel(x_gmem, o_gmem):
+      plgpu.emit_pipeline(
+          kernel_body,
+          in_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))],
+          out_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))],
+          grid=(num_steps,),
+          max_concurrent_steps=2,
+      )(x_gmem, o_gmem)
+
+    def kernel_body(_, in_smem, o_smem):
+      o_smem[...] = in_smem[...] + 1
+
+    kernel_fn = self.pallas_call(
+        kernel,
+        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
+        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
+        out_shape=jax.ShapeDtypeStruct((16, 256), jnp.int32),
+    )
+    x = jnp.arange(16 * 256, dtype=jnp.int32).reshape(16, 256)
+    np.testing.assert_array_equal(kernel_fn(x), x + 1)
+

 class PipelineWGTest(
     PipelineTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup
@@ -5423,6 +5450,38 @@ def pipeline_body(_, x_smem, o_smem):
     )
     np.testing.assert_array_equal(y, np.stack([x + 1.0, x + 1.0]))

+  @parameterized.parameters((pl.Squeezed(),), (None,))
+  def test_emit_with_squeezed_dim(self, squeezed_dim):
+    self.skip_if_wg_semantics()
+
+    shape = (16, 256)
+    num_steps = shape[0]
+
+    def kernel(x_gmem, o_gmem):
+      plgpu.emit_pipeline_warp_specialized(
+          kernel_body,
+          in_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))],
+          out_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))],
+          grid=(num_steps,),
+          max_concurrent_steps=2,
+          num_compute_wgs=1,
+          memory_registers=40,
+          wg_axis="wg",
+      )(x_gmem, o_gmem)
+
+    def kernel_body(_, in_smem, o_smem):
+      o_smem[...] = in_smem[...] + 1
+
+    kernel_fn = self.kernel(
+        kernel,
+        out_shape=jax.ShapeDtypeStruct((16, 256), jnp.int32),
+        num_threads=2,
+        thread_name="wg",
+    )
+    x = jnp.arange(16 * 256, dtype=jnp.int32).reshape(16, 256)
+    np.testing.assert_array_equal(kernel_fn(x), x + 1)
+
+

 class WarpSpecializedPipelineWGTest(
     WarpSpecializedPipelineTest,
