Skip to content

Commit 44f67b0

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Add concrete_mesh to reshard_p in pallas lowering rules which was missing
PiperOrigin-RevId: 845766226
1 parent d3a8052 commit 44f67b0

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3393,7 +3393,8 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
33933393

33943394

33953395
@register_lowering_rule(pjit.reshard_p)
3396-
def _reshard_lowering_rule(ctx: LoweringRuleContext, x, dst_sharding):
3396+
def _reshard_lowering_rule(ctx: LoweringRuleContext, x, *, dst_sharding,
3397+
concrete_mesh):
33973398
return x
33983399

33993400

jax/_src/pallas/triton/lowering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2537,7 +2537,7 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
25372537

25382538

25392539
@register_lowering(pjit.reshard_p)
2540-
def _reshard_lowering_rule(ctx, x, dst_sharding):
2540+
def _reshard_lowering_rule(ctx, x, *, dst_sharding, concrete_mesh):
25412541
return x
25422542

25432543

0 commit comments

Comments
 (0)