We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1828c75 commit ef0618fCopy full SHA for ef0618f
jax/_src/pallas/fuser/block_spec.py
@@ -2164,7 +2164,7 @@ def _transpose_push_rule(
2164
) -> pallas_core.BlockSpec:
2165
del ctx
2166
block_shape = block_spec.block_shape
2167
- new_shape = [block_shape[i] for i in permutation]
+ new_shape = tuple(block_shape[i] for i in permutation)
2168
if set(permutation[-2:]) != {permutation[-1], permutation[-2]}:
2169
raise NotImplementedError(
2170
'Cannot permute last two dimensions with leading dimensions.'
0 commit comments