@@ -2265,6 +2265,44 @@ def run(num_grid, s, x):
22652265 error_message ,
22662266 )
22672267
def test_automatic_single_buffering(self):
  """Compiles no-op kernels to check single-buffering of output windows.

  Nothing is executed: each case only lowers and compiles a kernel with an
  empty body. The large output is 100 MiB of int8 and is tiled in 10 MiB
  blocks. Cases whose output index map selects a single window are expected
  to compile, the case whose window varies across the grid is expected to
  fail with an out-of-memory error, and the same case is expected to compile
  once `pl.Buffered(1)` is requested explicitly.
  """
  if self.INTERPRET:
    self.skipTest('OOM tests need us to compile the kernels')
  # NOTE(review): the version gate below is 2025-11-12 while the skip message
  # says Oct 14, 2025 -- confirm which date is intended.
  if not jtu.if_cloud_tpu_at_least(2025, 11, 12):
    self.skipTest('Support added on Oct 14, 2025')

  block_elems = 10 * 1024 * 1024
  large_out = jax.ShapeDtypeStruct((100 * 1024 * 1024,), jnp.int8)
  small_out = jax.ShapeDtypeStruct((block_elems,), jnp.int8)

  def noop_kernel(y_ref):
    del y_ref  # Intentionally empty: we only want to compile the kernel.

  def window(index_map, **spec_options):
    # One-dimensional output block of `block_elems` elements.
    return pl.BlockSpec((block_elems,), index_map, **spec_options)

  def compile_for(grid, out_shape, **call_options):
    # Lower and compile only; the compiled kernel is never run.
    call = self.pallas_call(
        noop_kernel, grid=grid, out_shape=out_shape, **call_options
    )
    return call.lower().compile()

  # No explicit out_specs: should be recognized as loading a single window.
  compile_for((4,), small_out)

  # The index map depends only on the grid dimension of size 1, so a single
  # window should be recognized regardless of which axis that is.
  compile_for((4, 1), large_out, out_specs=window(lambda i, j: (j,)))
  compile_for((1, 4), large_out, out_specs=window(lambda i, j: (i,)))

  # The window now changes along the size-4 grid dimension, so different
  # windows are extracted and compilation should run out of memory.
  with self.assertRaisesRegex(
      jax.errors.JaxRuntimeError, '(Ran out of memory)|(exceed memory)'
  ):
    compile_for((4, 1), large_out, out_specs=window(lambda i, j: (j + i,)))

  # Explicitly requesting single-buffering should make the same case fit.
  compile_for(
      (4, 1),
      large_out,
      out_specs=window(lambda i, j: (j + i,), pipeline_mode=pl.Buffered(1)),
  )
2305+
22682306 def test_allow_input_fusion (self ):
22692307 shape = (3 , 128 , 128 )
22702308
0 commit comments