Skip to content

Commit ef0618f

Browse files
[pallas fuser] transpose push rule returns block shape as tuple
PiperOrigin-RevId: 832465208
1 parent 1828c75 commit ef0618f

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jax/_src/pallas/fuser/block_spec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2164,7 +2164,7 @@ def _transpose_push_rule(
21642164
) -> pallas_core.BlockSpec:
21652165
del ctx
21662166
block_shape = block_spec.block_shape
2167-
new_shape = [block_shape[i] for i in permutation]
2167+
new_shape = tuple(block_shape[i] for i in permutation)
21682168
if set(permutation[-2:]) != {permutation[-1], permutation[-2]}:
21692169
raise NotImplementedError(
21702170
'Cannot permute last two dimensions with leading dimensions.'

0 commit comments

Comments
 (0)