Skip to content

Commit dc03b06

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[pallas:sc] Skipped a few more tests under TC tiling due to verification errors
PiperOrigin-RevId: 833257620
1 parent e63d2a4 commit dc03b06

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

jaxlib/mosaic/dialect/tpu/tpu_dialect.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ std::pair<bool, bool> mightCommunicateBetweenChips(Operation *op);
5959

6060
std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
6161
int hardware_generation, std::array<int64_t, 2> target_shape,
62-
const TpuTilingFlags& tpu_tiling_flags, bool align = true);
62+
const TpuTilingFlags& tpu_tiling_flags, bool align = true,
63+
bool infer_kernel_arguments = true);
6364

6465
#define GEN_PASS_DECL_MOSAICSERDEPASS
6566
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"

tests/pallas/tpu_sparsecore_pallas_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,13 @@ def kernel(x_ref, o_ref):
296296
@hp.given(hps.data())
297297
def test_block_spec_untiled_slicing(self, data):
298298
if not self.USE_TC_TILING:
299-
self.skipTest("Test uncoveres a bug: @reproduce_failure('6.80.0', b'AAEBAQAAAAA=')")
299+
self.skipTest(
300+
"Test uncovers a bug: @reproduce_failure('6.80.0', b'AAEBAQAAAAA=')"
301+
)
302+
else:
303+
self.skipTest(
304+
"Test uncovers a bug: @reproduce_failure('6.80.0', b'AAEAAQAAAAA=')"
305+
)
300306
slice_shape = data.draw(
301307
hps.lists(
302308
hps.integers(1, 3), min_size=(1 + self.USE_TC_TILING), max_size=4
@@ -1065,7 +1071,7 @@ def scoped_kernel(scratch_ref):
10651071
pl.run_scoped(
10661072
scoped_kernel,
10671073
plsc.MemoryRef(
1068-
x.shape, x_ref.dtype, memory_space=pltpu.VMEM, tiling=[(2, 1)]
1074+
x.shape, x_ref.dtype, memory_space=pltpu.VMEM, tiling=[(1, 8)]
10691075
),
10701076
)
10711077

0 commit comments

Comments
 (0)