Merged: jax/_src/pallas/mosaic_gpu/pipeline.py (82 changes: 23 additions, 59 deletions)
```diff
@@ -182,26 +182,6 @@ def _inc_grid_by_1(
 def _in_smem(spec: pallas_core.BlockSpec) -> bool:
   return spec.memory_space in (None, gpu_core.SMEM)
 
-
-# ``pl.Slice`` uses a different pytree encoding, depending on whether the
-# start/size are static or dynamic. This leads to pytree structure mismatch
-# in the pipeline body. So, we define a different ``Slice`` class below.
-
-
-@dataclasses.dataclass(frozen=True)
-class _Slice:
-  start: int | jax.Array
-  size: int | jax.Array
-
-  def __eq__(self, other: _Slice) -> jax.Array:  # type: ignore
-    return lax.bitwise_and(self.start == other.start, self.size == other.size)
-
-
-jax.tree_util.register_dataclass(
-    _Slice, data_fields=["start", "size"], meta_fields=[]
-)
-
-
 def _downcast_spec(
     spec: gpu_core.BlockSpec | pallas_core.BlockSpec,
 ) -> gpu_core.BlockSpec:
```
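This change drops the `_Slice` workaround entirely: the loop carry no longer holds GMEM slices, so the pytree-encoding mismatch described in the deleted comment no longer arises. For context, the property `_Slice` provided can be reproduced in a minimal, self-contained sketch; `StableSlice` below is a hypothetical stand-in, not the pipeline's class. A dataclass whose fields are all registered as data flattens to the same treedef whether a field holds a static Python int or a traced array, which is what made it safe to keep in a `lax.fori_loop` carry:

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class StableSlice:  # hypothetical stand-in for the removed ``_Slice``
  start: int | jax.Array
  size: int | jax.Array

# Both fields are data, so they are always pytree leaves and the treedef
# does not depend on whether they are static ints or traced arrays.
jax.tree_util.register_dataclass(
    StableSlice, data_fields=["start", "size"], meta_fields=[]
)

static = StableSlice(0, 128)                # Python ints
dynamic = StableSlice(jnp.asarray(0), 128)  # array-valued start
assert (jax.tree_util.tree_structure(static)
        == jax.tree_util.tree_structure(dynamic))
```

The replacement carries the raw outputs of each `BlockSpec.index_map` instead, one integer per block dimension, which is already pytree-stable and cheaper to compare.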
```diff
@@ -357,7 +337,7 @@ def prologue(step, fetch_indices):
     # need to fetch more data anyway.
     def loop_body(step, carry):
       slot = lax.rem(step, max_concurrent_steps)
-      indices, fetch_index_levels, last_store_slices, prev_body_carry = carry
+      indices, fetch_index_levels, last_store_indices, prev_body_carry = carry
 
       if barrier_ref is not None:
         # Wait for the current GMEM->SMEM copy to complete, if any.
@@ -381,20 +361,17 @@ def loop_body(step, carry):
         gpu_primitives.commit_smem()
 
       # Copy the output from SMEM to GMEM.
-      new_store_slices = last_store_slices[:]
+      new_store_indices = last_store_indices[:]
       for idx, bref in enumerate(out_brefs):
         if bref.is_index_invariant:
-          assert last_store_slices[idx] is None
+          assert last_store_indices[idx] is None
           continue
-        assert last_store_slices[idx] is not None
-        new_store_slices[idx] = tuple(
-            _Slice(s.start, s.size) if isinstance(s, pl.Slice) else s
-            for s in bref.compute_gmem_slice(indices)
-        )
+        assert last_store_indices[idx] is not None
+        new_store_indices[idx] = bref.spec.index_map(*indices)
         are_same_slices = map(
             lambda old, new: old == new,
-            last_store_slices[idx],
-            new_store_slices[idx],
+            last_store_indices[idx],
+            new_store_indices[idx],
         )
         slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices)
         is_last_step = step == num_steps - 1
```
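The store-elision logic keeps the same shape but now compares index-map outputs instead of reconstructed GMEM slices: the SMEM->GMEM copy is predicated on at least one block index having changed since the previous store. A runnable sketch of just that comparison, where `store_needed` and the toy index tuples are illustrative and not the pipeline's API:

```python
import functools

import jax.numpy as jnp
from jax import lax

def store_needed(last_indices, new_indices):
  # Elementwise equality over the index tuples, reduced with bitwise_and
  # and negated: True iff any component changed since the last store.
  same = map(lambda old, new: old == new, last_indices, new_indices)
  return ~functools.reduce(lax.bitwise_and, same)

last = (jnp.array(-1), jnp.array(-1))  # sentinel used at initialization
new = (jnp.array(0), jnp.array(3))
print(store_needed(last, new))  # True: the block moved, so write it out
print(store_needed(new, new))   # False: same block, skip the redundant store
```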
Expand Down Expand Up @@ -436,7 +413,7 @@ def do_fetch():
return (
_inc_grid_by_1(indices, grid),
next_fetch_indices_levels,
new_store_slices,
new_store_indices,
next_body_carry if init_carry is not None else None,
)

Expand All @@ -447,23 +424,18 @@ def do_fetch():
fetch_indices = _inc_grid_by_1(fetch_indices, grid)
fetch_index_levels.append(fetch_indices)

def _init_store_slice(bd):
if bd is None or isinstance(bd, pl.Squeezed):
return jnp.array(-1, dtype=jnp.int32)
return _Slice(-1, -1)

# TODO(justinfu): Only store base pointer instead of all indices.
last_store_slices = [
last_store_indices = [
None
if bref.is_index_invariant
else tuple(map(_init_store_slice, bref.spec.block_shape))
else (jnp.array(-1),) * len(bref.spec.block_shape)
for bref in out_brefs
]
last_indices, _, _, final_carry = lax.fori_loop(
0,
num_steps,
loop_body,
(indices, fetch_index_levels, last_store_slices, init_carry),
(indices, fetch_index_levels, last_store_indices, init_carry),
)

# Outputs invariant to the sequential axis are never written from inside the
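The initializer mirrors the structure that `index_map` will later produce: a tuple of `-1` sentinels, one per block dimension, for index-dependent outputs, and `None` for index-invariant ones. Since `-1` never matches a real block index, the first iteration always stores. The fixed structure matters because `lax.fori_loop` requires the carry pytree to keep the same treedef on every step, which is exactly the constraint the removed `_Slice` machinery was working around. A toy sketch of that constraint; the 2D `index_map` and the grid walk are made up:

```python
import jax.numpy as jnp
from jax import lax

def index_map(i, j):  # hypothetical BlockSpec-style index map
  return i, j

block_rank = 2
# Sentinels with the same pytree structure as index_map's output.
init = (jnp.array(-1, dtype=jnp.int32),) * block_rank

def body(step, last_indices):
  new_indices = index_map(step // 4, step % 4)
  # `new_indices` flattens exactly like `init`, so the carry treedef
  # stays stable across iterations, as lax.fori_loop requires.
  return tuple(jnp.asarray(x) for x in new_indices)

print(lax.fori_loop(0, 8, body, init))  # last visited block: (1, 3)
```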
```diff
@@ -848,7 +820,7 @@ def compute_block():
   needs_epilogue = any(bref.is_index_invariant for bref in smem_out_brefs)
 
   def compute_loop_body(step, carry):
-    indices, last_store_slices, prev_body_carry = carry
+    indices, last_store_indices, prev_body_carry = carry
     slot = lax.rem(step, max_concurrent_steps)
     consumed_slot = lax.rem(step - delay_release, max_concurrent_steps)
     # Wait for the current GMEM->SMEM copies to complete.
@@ -895,40 +867,32 @@ def compute_loop_body(step, carry):
     if copies_out_in_loop:
       gpu_primitives.commit_smem()
 
-    new_store_slices = last_store_slices[:]
+    new_store_indices = last_store_indices[:]
     for idx, bref in enumerate(flat_out_brefs):
       if bref.is_index_invariant:
-        assert last_store_slices[idx] is None
+        assert last_store_indices[idx] is None
         continue
-      assert last_store_slices[idx] is not None
-      new_store_slices[idx] = tuple(
-          _Slice(s.start, s.size) if isinstance(s, pl.Slice) else s
-          for s in bref.compute_gmem_slice(indices)
-      )
+      assert last_store_indices[idx] is not None
+      new_store_indices[idx] = bref.spec.index_map(*indices)
       are_same_slices = map(
           lambda old, new: old == new,
-          last_store_slices[idx],
-          new_store_slices[idx],
+          last_store_indices[idx],
+          new_store_indices[idx],
       )
       slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices)
       bref.copy_out(_get_slot(slot, not bref.is_index_invariant),
                     indices,
                     predicate=slices_changed)
     gpu_primitives.commit_smem_to_gmem_group()
     next_indices = _inc_grid_by_1(indices, grid)
-    return (next_indices, new_store_slices, next_body_carry)
+    return (next_indices, new_store_indices, next_body_carry)
   init_indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid)
 
-
-  def _init_store_slice(bd):
-    if bd is None or isinstance(bd, pl.Squeezed):
-      return jnp.array(-1, dtype=jnp.int32)
-    return _Slice(-1, -1)
-
-  last_store_slices = [
+  # TODO(justinfu): Only store base pointer instead of all indices.
+  last_store_indices = [
       None
       if bref.is_index_invariant
-      else tuple(map(_init_store_slice, bref.spec.block_shape))
+      else (jnp.array(-1),) * len(bref.spec.block_shape)
       for bref in flat_out_brefs
   ]
 
```
```diff
@@ -939,7 +903,7 @@ def pipeline_callback(user_init_carry):
     if last_indices is not None:
       raise ValueError(
           "Cannot call pipeline more than once in `compute_context`")
-    init_loop_carry = (init_indices, last_store_slices, user_init_carry)
+    init_loop_carry = (init_indices, last_store_indices, user_init_carry)
     last_indices, _, final_body_carry = lax.fori_loop(0,
                                                       num_steps,
                                                       compute_loop_body,
@@ -952,7 +916,7 @@ def pipeline_callback(user_init_carry):
     assert compute_context is None
     last_indices, _, _ = lax.fori_loop(
         0, num_steps, compute_loop_body,
-        (init_indices, last_store_slices, None)
+        (init_indices, last_store_indices, None)
     )
 
     # Handle index_invariant outputs after the loop. They are not
```
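The comment cut off above introduces the epilogue for index-invariant outputs: their GMEM destination does not depend on the sequential axis, so they are flushed once after the loop rather than re-stored on every step. A toy model of that pattern with plain JAX arrays standing in for SMEM/GMEM buffers (no Pallas APIs involved):

```python
import jax.numpy as jnp
from jax import lax

def pipeline_like(x_blocks):
  num_steps = x_blocks.shape[0]

  def body(step, acc):
    # Accumulate into the resident buffer each step; no per-step store.
    return acc + x_blocks[step]

  acc = lax.fori_loop(0, num_steps, body, jnp.zeros_like(x_blocks[0]))
  return acc  # the single "copy out" happens after the loop

print(pipeline_like(jnp.ones((4, 8))))  # every element is 4.0
```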