55 changes: 42 additions & 13 deletions jax/_src/pallas/mosaic_gpu/pipeline.py
@@ -59,8 +59,15 @@ def _get_block_size(
raise NotImplementedError(f"Unsupported block size type: {type(bd)}")

def _get_block_shape(spec: pallas_core.BlockSpec):
assert spec.block_shape is not None
return tuple(_get_block_size(bd) for bd in spec.block_shape)
if spec.block_shape is None:
raise ValueError("Block shape must be specified.")

block_shape = tuple(
_get_block_size(bd)
for bd in spec.block_shape
if not (bd is None or isinstance(bd, pl.Squeezed))
)
return block_shape


map_brefs = functools.partial(
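
# A minimal sketch (not part of this diff) of what the reworked _get_block_shape
# drops. It assumes the public jax.experimental.pallas alias `pl` that the new
# tests below use; the spec itself is hypothetical.
from jax.experimental import pallas as pl

spec = pl.BlockSpec((pl.Squeezed(), 256), lambda i: (i, 0))
# _get_block_shape(spec) now returns (256,): squeezed and None axes no longer
# contribute to the block shape, so the SMEM window backing this ref has no
# dimension for them.
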
@@ -84,18 +91,27 @@ def get_ref_for_slot(
return self.gmem_ref
return self.smem_ref.at[slot]

def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]:
def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice | jax.Array, ...]:
index_map = self.spec.index_map
assert index_map is not None
assert self.spec.block_shape is not None
# We don't allow Python scalars here, because they are interpreted
# differently depending on the x32/x64 mode.
assert all(i.dtype == jnp.dtype(jnp.int32) for i in grid_indices)
sizes = _get_block_shape(self.spec)

def _make_block_slice(block_index: jax.Array, bd: pl.BlockDim | int | None):
match bd:
case int():
return pl.Slice(block_index * bd, bd)
case pl.Blocked(block_size):
return pl.Slice(block_index * block_size, block_size)
case None | pl.Squeezed():
return block_index
case _:
raise ValueError(f"Unsupported block dimension type: {bd}")

return tuple(
pl.Slice(idx * size, size) # type: ignore[arg-type]
for idx, size in zip(
index_map(*grid_indices), sizes # type: ignore[arg-type]
)
map(_make_block_slice, index_map(*grid_indices), self.spec.block_shape)
)

def copy_in(self, slot, grid_indices, barrier_ref, barrier_slot=None):
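
# A minimal sketch of what compute_gmem_slice now yields for the hypothetical
# spec above, pl.BlockSpec((pl.Squeezed(), 256), lambda i: (i, 0)): for an
# int32 grid index i the returned tuple is roughly
#
#     (i, pl.Slice(0, 256))
#
# Squeezed/None axes are addressed by the raw index produced by the index map,
# while blocked axes keep pl.Slice windows scaled by their block size.
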
@@ -372,7 +388,8 @@ def loop_body(step, carry):
continue
assert last_store_slices[idx] is not None
new_store_slices[idx] = tuple(
_Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices)
_Slice(s.start, s.size) if isinstance(s, pl.Slice) else s
for s in bref.compute_gmem_slice(indices)
)
are_same_slices = map(
lambda old, new: old == new,
@@ -430,11 +447,16 @@ def do_fetch():
fetch_indices = _inc_grid_by_1(fetch_indices, grid)
fetch_index_levels.append(fetch_indices)

def _init_store_slice(bd):
if bd is None or isinstance(bd, pl.Squeezed):
return jnp.array(-1, dtype=jnp.int32)
return _Slice(-1, -1)

# TODO(justinfu): Only store base pointer instead of all indices.
last_store_slices = [
None
if bref.is_index_invariant
else (_Slice(-1, -1),) * len(bref.spec.block_shape)
else tuple(map(_init_store_slice, bref.spec.block_shape))
for bref in out_brefs
]
last_indices, _, _, final_carry = lax.fori_loop(
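
# A brief note, inferred from the change above: the fori_loop carry must keep
# the same structure on every iteration, and squeezed/None axes now contribute
# a raw int32 index (not a _Slice) to new_store_slices. _init_store_slice
# therefore seeds those axes with a scalar sentinel, e.g.
#
#     jnp.array(-1, dtype=jnp.int32)   # squeezed/None axis
#     _Slice(-1, -1)                   # blocked axis, as before
#
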
@@ -690,7 +712,7 @@ def _get_scoped_allocs(*gmem_refs: AbstractRefPytree):
slots = max_concurrent_steps if has_seq_dim else 1
smem_allocs.append(
gpu_core.SMEM(
(slots, *spec.block_shape), # type: ignore
(slots, *_get_block_shape(spec)), # type: ignore
gmem_ref.dtype,
transforms=getattr(spec, "transforms", ()),
)
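
# A minimal sketch of the allocation change, reusing the hypothetical spec from
# above: for pl.BlockSpec((pl.Squeezed(), 256), ...) with `slots` pipeline
# stages, the SMEM buffer is now allocated with shape (slots, 256). The squeezed
# axis is absent, which is why the kernel bodies in the new tests see SMEM refs
# of shape (256,).
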
@@ -880,7 +902,8 @@ def compute_loop_body(step, carry):
continue
assert last_store_slices[idx] is not None
new_store_slices[idx] = tuple(
_Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices)
_Slice(s.start, s.size) if isinstance(s, pl.Slice) else s
for s in bref.compute_gmem_slice(indices)
)
are_same_slices = map(
lambda old, new: old == new,
@@ -895,11 +918,17 @@ def compute_loop_body(step, carry):
next_indices = _inc_grid_by_1(indices, grid)
return (next_indices, new_store_slices, next_body_carry)
init_indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid)

def _init_store_slice(bd):
if bd is None or isinstance(bd, pl.Squeezed):
return jnp.array(-1, dtype=jnp.int32)
return _Slice(-1, -1)

# TODO(justinfu): Only store base pointer instead of all indices.
last_store_slices = [
None
if bref.is_index_invariant
else (_Slice(-1, -1),) * len(bref.spec.block_shape)
else tuple(map(_init_store_slice, bref.spec.block_shape))
for bref in flat_out_brefs
]

61 changes: 61 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
@@ -4972,6 +4972,35 @@ def kernel_body(_, o_smem, carry):
kernel_fn(), jnp.tile(jnp.repeat(jnp.arange(num_steps), 64), (64, 1))
)

@parameterized.parameters((pl.Squeezed(),), (None,))
def test_emit_with_squeezed_dim(self, squeezed_dim):

shape = (16, 256)
num_steps = shape[0]

def kernel(x_gmem, o_gmem):
plgpu.emit_pipeline(
kernel_body,
in_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))],
out_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))],
grid=(num_steps,),
max_concurrent_steps=2,
)(x_gmem, o_gmem)

def kernel_body(_, in_smem, o_smem):
assert in_smem.shape == (shape[1],)
assert o_smem.shape == (shape[1],)
o_smem[...] = in_smem[...] + 1

kernel_fn = self.pallas_call(
kernel,
in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
out_shape=jax.ShapeDtypeStruct((16, 256), jnp.int32),
)
x = jnp.arange(16 * 256, dtype=jnp.int32).reshape(16, 256)
np.testing.assert_array_equal(kernel_fn(x), x + 1)


class PipelineWGTest(
PipelineTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup
@@ -5661,6 +5690,38 @@ def pipeline_body(_, x_smem, o_smem):
)
np.testing.assert_array_equal(y, np.stack([x + 1.0, x + 1.0]))

@parameterized.parameters((pl.Squeezed(),), (None,))
def test_emit_with_squeezed_dim(self, squeezed_dim):
self.skip_if_wg_semantics()

shape = (16, 256)
num_steps = shape[0]

def kernel(x_gmem, o_gmem):
plgpu.emit_pipeline_warp_specialized(
kernel_body,
in_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))],
out_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))],
grid=(num_steps,),
max_concurrent_steps=2,
num_compute_wgs=1,
memory_registers=40,
wg_axis="wg",
)(x_gmem, o_gmem)

def kernel_body(_, in_smem, o_smem):
o_smem[...] = in_smem[...] + 1

kernel_fn = self.kernel(
kernel,
out_shape=jax.ShapeDtypeStruct((16, 256), jnp.int32),
num_threads=2,
thread_name="wg",
)
x = jnp.arange(16 * 256, dtype=jnp.int32).reshape(16, 256)
np.testing.assert_array_equal(kernel_fn(x), x + 1)



class WarpSpecializedPipelineWGTest(
WarpSpecializedPipelineTest,