Skip to content

Commit ad2b914

Browse files
sharadmv authored and Google-ML-Automation committed
[Pallas] Allow no mesh context if just signaling on core axis
PiperOrigin-RevId: 845435562
1 parent ba024e3 commit ad2b914

File tree

2 files changed

+44
-18
lines changed

2 files changed

+44
-18
lines changed

jax/_src/pallas/primitives.py

Lines changed: 26 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1351,20 +1351,26 @@ def _semaphore_wait_discharge_rule(in_avals,
13511351
)
13521352

13531353

1354-
def _device_id_dict_to_mesh(mesh_context: pallas_utils.MeshInfo, device_id_dict, get_axis_index):
1354+
def _device_id_dict_to_mesh(mesh_context: pallas_utils.MeshInfo | None, device_id_dict, get_axis_index):
13551355
i32 = ir.IntegerType.get_signless(32)
1356-
assert mesh_context is not None
1357-
mesh_axis_sizes = dict(zip(mesh_context.axis_names, mesh_context.mesh_shape))
1356+
if mesh_context is None:
1357+
mesh_axis_sizes = {}
1358+
else:
1359+
mesh_axis_sizes = dict(
1360+
zip(mesh_context.axis_names, mesh_context.mesh_shape)
1361+
)
13581362
physical_axis_dict = {}
13591363
# Handle joint axes (i.e., one logical axis over >1 physical axes)
1360-
for axis, idx in device_id_dict.items():
1361-
if isinstance(axis, tuple) and any(a in mesh_context.axis_names for a in axis):
1362-
if not all(a in mesh_context.axis_names for a in axis):
1364+
for axis_name, idx in device_id_dict.items():
1365+
if isinstance(axis_name, tuple) and any(
1366+
a in mesh_axis_sizes for a in axis_name
1367+
):
1368+
if not all(a in mesh_axis_sizes for a in axis_name):
13631369
raise NotImplementedError(
1364-
f"{axis} mixes JAX mesh and Pallas mesh grid axes"
1370+
f"{axis_name} mixes JAX mesh and Pallas mesh grid axes"
13651371
)
1366-
axes_dimensions = [mesh_axis_sizes[name] for name in axis]
1367-
for axis_index, axis_name in enumerate(axis):
1372+
axes_dimensions = [mesh_axis_sizes[name] for name in axis_name]
1373+
for axis_index, axis_name in enumerate(axis_name):
13681374
axis_size = mesh_axis_sizes[axis_name]
13691375
inner_mesh_size = math.prod(axes_dimensions[axis_index + 1 :])
13701376
minor_divisor = arith.constant(i32, inner_mesh_size)
@@ -1387,17 +1393,17 @@ def _device_id_dict_to_mesh(mesh_context: pallas_utils.MeshInfo, device_id_dict,
13871393
)
13881394
physical_axis_dict[axis_name] = device_idx
13891395
else:
1390-
physical_axis_dict[axis] = idx
1396+
physical_axis_dict[axis_name] = idx
13911397
device_id = []
1392-
for axis in mesh_context.axis_names:
1393-
if axis in physical_axis_dict:
1394-
device_id.append(physical_axis_dict[axis])
1398+
for axis_name in mesh_axis_sizes:
1399+
if axis_name in physical_axis_dict:
1400+
device_id.append(physical_axis_dict[axis_name])
13951401
else:
1396-
device_id.append(get_axis_index(axis))
1402+
device_id.append(get_axis_index(axis_name))
13971403
non_mesh_axes = {
13981404
k: v
13991405
for k, v in physical_axis_dict.items()
1400-
if k not in mesh_context.axis_names
1406+
if k not in mesh_axis_sizes
14011407
}
14021408
return tuple(device_id), non_mesh_axes
14031409

@@ -1419,13 +1425,15 @@ def device_id_to_logical(
14191425
"`device_id_type` must be MESH if `device_id` is a dict,"
14201426
f" got: {device_id_type = }."
14211427
)
1422-
assert mesh_context is not None
14231428
device_id, non_mesh_axes = _device_id_dict_to_mesh(mesh_context, device_id, get_axis_index)
14241429
if device_id_type is DeviceIdType.MESH:
1425-
assert mesh_context is not None
14261430
# Mesh means we are passed the mesh coordinates for the device
14271431
device_ids = tree_util.tree_leaves(device_id)
1428-
mesh_strides = mesh_context.mesh_strides
1432+
mesh_strides: tuple[int, ...]
1433+
if mesh_context is None:
1434+
mesh_strides = ()
1435+
else:
1436+
mesh_strides = mesh_context.mesh_strides
14291437
if len(device_ids) != len(mesh_strides):
14301438
raise ValueError(
14311439
"Number of device ids must match the number of mesh axes, but got"

tests/pallas/tpu_pallas_state_test.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -273,6 +273,24 @@ def _():
273273
"Attempted to lower core_map without discharging."):
274274
f(x)
275275

276+
def test_can_signal_cores(self):
277+
@jax.jit
278+
def f(x):
279+
x_ref = jax.new_ref(x)
280+
y_ref = jax.new_ref(jnp.empty_like(x))
281+
@pl.core_map(pltpu.create_tensorcore_mesh("x"))
282+
def _():
283+
@functools.partial(pl.run_scoped, sem=pltpu.SemaphoreType.REGULAR)
284+
def inner(sem):
285+
s = jax.lax.axis_size("x")
286+
for i in range(s):
287+
pl.semaphore_signal(sem, device_id={"x": i})
288+
pl.semaphore_wait(sem, s)
289+
pltpu.sync_copy(x_ref, y_ref)
290+
return jax.freeze(y_ref)
291+
x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
292+
np.testing.assert_array_equal(f(x), x)
293+
276294
def test_can_query_core_index(self):
277295
mesh = pltpu.create_tensorcore_mesh("x")
278296
slc_size = 16 // mesh.shape["x"]

0 commit comments

Comments
 (0)