Commit e2815d5

superbobry authored and Google-ML-Automation committed
[pallas:mosaic] pltpu.emit_pipeline now accepts block specs in HBM
This makes it possible to use it to implement pipelining in the `pallas_call` lowering on SparseCore.

PiperOrigin-RevId: 846255621
1 parent 69a6cca commit e2815d5
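
For orientation: the diff below widens the set of buffer memory spaces that block specs passed to pltpu.emit_pipeline may request from {SMEM, VMEM} to {SMEM, VMEM, HBM}, so a pipeline operand can stay resident in HBM instead of being staged through VMEM. A self-contained sketch of the pattern this enables, adapted from the test added in this commit (TPU-only; the shapes and grid are illustrative):

import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((8, 512), jnp.int32),
    in_specs=[pl.BlockSpec(memory_space=pltpu.HBM)],
    out_specs=pl.BlockSpec(memory_space=pltpu.HBM),
)
def kernel(x_hbm_ref, o_hbm_ref):
  @functools.partial(
      pltpu.emit_pipeline,
      grid=(4,),
      # Inputs are staged into VMEM one (8, 128) block per grid step...
      in_specs=pl.BlockSpec((8, 128), lambda i: (0, i)),
      # ...but the output block stays in HBM, which is what this commit enables.
      out_specs=pl.BlockSpec((8, 512), lambda i: (0, 0),
                             memory_space=pltpu.HBM),
  )
  def pipeline(x_ref, o_ref):
    i = pl.program_id(0)
    # Copy the current VMEM input block into its slice of the HBM output.
    pltpu.sync_copy(x_ref, o_ref.at[:, pl.ds(i * 128, 128)])

  pipeline(x_hbm_ref, o_hbm_ref)

x = jnp.arange(8 * 512, dtype=jnp.int32).reshape(8, 512)
np.testing.assert_allclose(kernel(x), x)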

File tree

jax/_src/pallas/mosaic/pipeline.py
tests/pallas/tpu_pallas_pipeline_test.py

2 files changed: +34 −14 lines changed


jax/_src/pallas/mosaic/pipeline.py

Lines changed: 9 additions & 14 deletions
@@ -40,6 +40,7 @@
 
 SMEM = tpu_core.MemorySpace.SMEM
 VMEM = tpu_core.MemorySpace.VMEM
+HBM = tpu_core.MemorySpace.HBM
 ANY = pallas_core.MemorySpace.ANY
 REF = pallas_core.MemoryRef
 GridDimensionSemantics = tpu_core.GridDimensionSemantics

@@ -582,13 +583,13 @@ def create(
       accum_ref = VMEM.from_type(ty.update(shape=block_shape))
     else:
       accum_ref = None
-    if source_memory_space == VMEM:
-      # We don't need to do any double-buffering in the case that our pipeline
-      # reference is already in VMEM, we just need allocate the accumulation
-      # buffer and we will refer to the original reference slices directly.
-      if spec.memory_space not in (VMEM, None):
-        raise ValueError(
-            f"Cannot hold a non-buffered ref in {spec.memory_space=}")
+    buffer_memory_space = (
+        VMEM if spec.memory_space is None else spec.memory_space)
+    if buffer_memory_space not in (SMEM, VMEM, HBM):
+      raise ValueError(
+          f"Unsupported buffer memory space: {buffer_memory_space}"
+      )
+    if source_memory_space is buffer_memory_space:
       return cls(
           _spec=spec,
           _buffer_type=buffer_type,

@@ -609,12 +610,6 @@ def create(
           swap=None,
       )
     else:
-      buffer_memory_space = (
-          VMEM if spec.memory_space is None else spec.memory_space)
-      if buffer_memory_space not in (SMEM, VMEM):
-        raise ValueError(
-            f"Unsupported buffer memory space: {buffer_memory_space}"
-        )
       if use_lookahead and grid_rank is None:
         raise ValueError(
             "grid_rank must be specified when use_lookahead is True."

@@ -1335,7 +1330,7 @@ def out_of_fetch(self, buffered_ref):
     # Currently this is based on the iteration, but if we want to support
     # lookahead this will depend on whether the lookahead reached the end.
     if not buffered_ref.is_buffered:
-      return False
+      return jnp.bool(False)
     return self.step >= (self.num_steps - buffered_ref.buffer_count + 1)
 
   def has_changed(self, buffered_ref):
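
Two functional changes are folded into this file. The buffer memory space is now resolved and validated up front, and the no-double-buffering fast path triggers whenever the source ref already lives in the requested buffer space, not only when both are VMEM. Separately, out_of_fetch returns jnp.bool(False) rather than Python False, which keeps its return type a JAX boolean on every path (the other branch returns a traced comparison). A minimal sketch of the resolution rule, using a stand-in enum in place of tpu_core.MemorySpace:

from enum import Enum

class MemorySpace(Enum):
  # Stand-in for tpu_core.MemorySpace; the real code uses that enum.
  SMEM = "smem"
  VMEM = "vmem"
  HBM = "hbm"

def resolve_buffer_memory_space(spec_memory_space, source_memory_space):
  """Mirrors, in simplified form, the new logic in pipeline.py's create()."""
  # A block spec with no explicit memory space defaults to VMEM.
  buffer_memory_space = (
      MemorySpace.VMEM if spec_memory_space is None else spec_memory_space)
  if buffer_memory_space not in (
      MemorySpace.SMEM, MemorySpace.VMEM, MemorySpace.HBM):
    raise ValueError(f"Unsupported buffer memory space: {buffer_memory_space}")
  # If the source already lives in the requested space, no staging buffer
  # (and hence no double-buffering) is needed; the pipeline can address
  # slices of the original ref directly.
  needs_staging = source_memory_space is not buffer_memory_space
  return buffer_memory_space, needs_staging

# Example: an HBM source with an HBM block spec needs no staging buffer.
assert resolve_buffer_memory_space(MemorySpace.HBM, MemorySpace.HBM) == (
    MemorySpace.HBM, False)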

tests/pallas/tpu_pallas_pipeline_test.py

Lines changed: 25 additions & 0 deletions
@@ -148,6 +148,31 @@ def body(o_ref):
     )()
     np.testing.assert_allclose(out, jnp.full_like(out, 42))
 
+  def test_hbm_output(self):
+    @functools.partial(
+        pl.pallas_call,
+        out_shape=jax.ShapeDtypeStruct((8, 512), jnp.int32),
+        in_specs=[pl.BlockSpec(memory_space=pltpu.HBM)],
+        out_specs=pl.BlockSpec(memory_space=pltpu.HBM),
+    )
+    def kernel(x_hbm_ref, o_hbm_ref):
+      @functools.partial(
+          pltpu.emit_pipeline,
+          grid=(4,),
+          in_specs=pl.BlockSpec((8, 128), lambda i: (0, i)),
+          out_specs=pl.BlockSpec(
+              (8, 512), lambda i: (0, 0), memory_space=pltpu.HBM
+          ),
+      )
+      def pipeline(x_ref, o_ref):
+        i = pl.program_id(0)
+        pltpu.sync_copy(x_ref, o_ref.at[:, pl.ds(i * 128, 128)])
+
+      pipeline(x_hbm_ref, o_hbm_ref)
+
+    x = jnp.arange(8 * 512).reshape(8, 512)
+    np.testing.assert_allclose(kernel(x), x)
+
   @parameterized.product(
       no_pipelining=[False, True],
   )
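
The test doubles as a usage reference: the inner pipeline's out_specs names a full-size (8, 512) block with a constant index map and memory_space=pltpu.HBM, so the output is never staged into VMEM; each of the four grid steps copies one 128-column VMEM input block directly into its slice of the HBM output via pltpu.sync_copy.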
