Commit df5d0a7

apaszke authored and Google-ML-Automation committed
[Mosaic TPU] Disable pipelining when the kernel always loads the same operand window
This lets us consume less VMEM and is overall more convenient for users.

PiperOrigin-RevId: 829402385
1 parent cc8ccfd commit df5d0a7
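
A minimal sketch of the pattern this change targets (illustrative only, not part of the commit): an operand whose index_map ignores the grid indices, so every grid step loads the same window. With this change, Mosaic TPU detects such operands and single-buffers them instead of double-buffering them, halving the VMEM those windows occupy. The kernel, names, and shapes below are assumptions for illustration.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_bias_kernel(x_ref, bias_ref, o_ref):
  # bias_ref maps to block (0, 0) on every grid step, so its window never
  # changes; this is the case the commit now recognizes automatically.
  o_ref[...] = x_ref[...] + bias_ref[...]

def add_bias(x, bias):
  n_blocks = x.shape[0] // 256
  return pl.pallas_call(
      add_bias_kernel,
      grid=(n_blocks,),
      in_specs=[
          pl.BlockSpec((256, 256), lambda i: (i, 0)),  # new window each step
          pl.BlockSpec((256, 256), lambda i: (0, 0)),  # same window each step
      ],
      out_specs=pl.BlockSpec((256, 256), lambda i: (i, 0)),
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )(x, bias)

x = jnp.ones((1024, 256), jnp.float32)
bias = jnp.ones((256, 256), jnp.float32)
print(add_bias(x, bias).shape)  # (1024, 256)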


tests/pallas/tpu_pallas_test.py

Lines changed: 38 additions & 0 deletions
@@ -2265,6 +2265,44 @@ def run(num_grid, s, x):
         error_message,
     )
 
+  def test_automatic_single_buffering(self):
+    if self.INTERPRET:
+      self.skipTest('OOM tests need us to compile the kernels')
+    if not jtu.if_cloud_tpu_at_least(2025, 11, 12):
+      self.skipTest('Support added on Oct 14, 2025')
+
+    def body(y_ref):
+      pass  # We only want to compile the kernel.
+
+    x = jax.ShapeDtypeStruct((100 * 1024 * 1024,), jnp.int8)
+    x_small = jax.ShapeDtypeStruct((10 * 1024 * 1024,), jnp.int8)
+    # Should recognize that the block specs only load a single window.
+    self.pallas_call(body, grid=(4,), out_shape=x_small).lower().compile()
+    # Should recognize that the block specs only load a single window, as it
+    # only depends on the 1-sized grid dim.
+    self.pallas_call(
+        body, grid=(4, 1), out_shape=x,
+        out_specs=pl.BlockSpec((10 * 1024 * 1024,), lambda i, j: (j,))
+    ).lower().compile()
+    self.pallas_call(
+        body, grid=(1, 4), out_shape=x,
+        out_specs=pl.BlockSpec((10 * 1024 * 1024,), lambda i, j: (i,))
+    ).lower().compile()
+    # Should OOM, as now we are extracting different windows.
+    with self.assertRaisesRegex(
+        jax.errors.JaxRuntimeError, '(Ran out of memory)|(exceed memory)'
+    ):
+      self.pallas_call(
+          body, grid=(4, 1), out_shape=x,
+          out_specs=pl.BlockSpec((10 * 1024 * 1024,), lambda i, j: (j + i,))
+      ).lower().compile()
+    # Explicitly setting single-buffering should fix it, though.
+    self.pallas_call(
+        body, grid=(4, 1), out_shape=x,
+        out_specs=pl.BlockSpec((10 * 1024 * 1024,), lambda i, j: (j + i,),
+                               pipeline_mode=pl.Buffered(1))
+    ).lower().compile()
+
   def test_allow_input_fusion(self):
     shape = (3, 128, 128)
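
The last case in the test shows the manual escape hatch: when the window genuinely changes across grid steps but double buffering does not fit in VMEM, passing pipeline_mode=pl.Buffered(1) to the BlockSpec requests single buffering explicitly, trading the compute/DMA overlap for half the buffer footprint. A hedged sketch of that usage outside the test harness (kernel and shapes are illustrative):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...]

x = jnp.ones((4, 1024), jnp.float32)
out = pl.pallas_call(
    copy_kernel,
    grid=(4,),
    in_specs=[
        # The window changes every step, so it cannot be single-buffered
        # automatically; pl.Buffered(1) opts in to it explicitly.
        pl.BlockSpec((1, 1024), lambda i: (i, 0),
                     pipeline_mode=pl.Buffered(1)),
    ],
    out_specs=pl.BlockSpec((1, 1024), lambda i: (i, 0)),
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
)(x)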
