44 changes: 26 additions & 18 deletions jax/_src/pallas/primitives.py
@@ -1351,20 +1351,26 @@ def _semaphore_wait_discharge_rule(in_avals,
   )
 
 
-def _device_id_dict_to_mesh(mesh_context: pallas_utils.MeshInfo, device_id_dict, get_axis_index):
+def _device_id_dict_to_mesh(mesh_context: pallas_utils.MeshInfo | None, device_id_dict, get_axis_index):
   i32 = ir.IntegerType.get_signless(32)
-  assert mesh_context is not None
-  mesh_axis_sizes = dict(zip(mesh_context.axis_names, mesh_context.mesh_shape))
+  if mesh_context is None:
+    mesh_axis_sizes = {}
+  else:
+    mesh_axis_sizes = dict(
+        zip(mesh_context.axis_names, mesh_context.mesh_shape)
+    )
   physical_axis_dict = {}
   # Handle joint axes (i.e., one logical axis over >1 physical axes)
-  for axis, idx in device_id_dict.items():
-    if isinstance(axis, tuple) and any(a in mesh_context.axis_names for a in axis):
-      if not all(a in mesh_context.axis_names for a in axis):
+  for axis_name, idx in device_id_dict.items():
+    if isinstance(axis_name, tuple) and any(
+        a in mesh_axis_sizes for a in axis_name
+    ):
+      if not all(a in mesh_axis_sizes for a in axis_name):
         raise NotImplementedError(
-            f"{axis} mixes JAX mesh and Pallas mesh grid axes"
+            f"{axis_name} mixes JAX mesh and Pallas mesh grid axes"
         )
-      axes_dimensions = [mesh_axis_sizes[name] for name in axis]
-      for axis_index, axis_name in enumerate(axis):
+      axes_dimensions = [mesh_axis_sizes[name] for name in axis_name]
+      for axis_index, axis_name in enumerate(axis_name):
         axis_size = mesh_axis_sizes[axis_name]
         inner_mesh_size = math.prod(axes_dimensions[axis_index + 1 :])
         minor_divisor = arith.constant(i32, inner_mesh_size)
@@ -1387,17 +1393,17 @@ def _device_id_dict_to_mesh(mesh_context: pallas_utils.MeshInfo, device_id_dict,
         )
         physical_axis_dict[axis_name] = device_idx
     else:
-      physical_axis_dict[axis] = idx
+      physical_axis_dict[axis_name] = idx
   device_id = []
-  for axis in mesh_context.axis_names:
-    if axis in physical_axis_dict:
-      device_id.append(physical_axis_dict[axis])
+  for axis_name in mesh_axis_sizes:
+    if axis_name in physical_axis_dict:
+      device_id.append(physical_axis_dict[axis_name])
     else:
-      device_id.append(get_axis_index(axis))
+      device_id.append(get_axis_index(axis_name))
   non_mesh_axes = {
       k: v
       for k, v in physical_axis_dict.items()
-      if k not in mesh_context.axis_names
+      if k not in mesh_axis_sizes
   }
   return tuple(device_id), non_mesh_axes
 
@@ -1419,13 +1425,15 @@ def device_id_to_logical(
           "`device_id_type` must be MESH if `device_id` is a dict,"
           f" got: {device_id_type = }."
       )
-    assert mesh_context is not None
     device_id, non_mesh_axes = _device_id_dict_to_mesh(mesh_context, device_id, get_axis_index)
   if device_id_type is DeviceIdType.MESH:
-    assert mesh_context is not None
     # Mesh means we are passed the mesh coordinates for the device
     device_ids = tree_util.tree_leaves(device_id)
-    mesh_strides = mesh_context.mesh_strides
+    mesh_strides: tuple[int, ...]
+    if mesh_context is None:
+      mesh_strides = ()
+    else:
+      mesh_strides = mesh_context.mesh_strides
     if len(device_ids) != len(mesh_strides):
       raise ValueError(
           "Number of device ids must match the number of mesh axes, but got"
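For readers skimming the diff: the joint-axis branch above decomposes a single flat index given for a tuple of mesh axes into per-axis device coordinates, row-major. Below is a minimal pure-Python sketch of that decomposition, assuming the collapsed part of the hunk finishes with the usual `// inner_size % axis_size` step; the axis names and the 2x4 mesh are made up for illustration and are not part of this PR.

```python
import math

def decompose_joint_index(idx, axis_names, mesh_axis_sizes):
  # Row-major split of a flat index over `axis_names` into per-axis coordinates,
  # mirroring inner_mesh_size = math.prod(axes_dimensions[axis_index + 1:]) above.
  dims = [mesh_axis_sizes[name] for name in axis_names]
  out = {}
  for axis_index, axis_name in enumerate(axis_names):
    inner_mesh_size = math.prod(dims[axis_index + 1:])
    out[axis_name] = (idx // inner_mesh_size) % mesh_axis_sizes[axis_name]
  return out

# Example: on a hypothetical 2x4 mesh over ("x", "y"), flat index 5 is device (1, 1).
assert decompose_joint_index(5, ("x", "y"), {"x": 2, "y": 4}) == {"x": 1, "y": 1}
```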
18 changes: 18 additions & 0 deletions tests/pallas/tpu_pallas_state_test.py
@@ -273,6 +273,24 @@ def _():
         "Attempted to lower core_map without discharging."):
       f(x)
 
+  def test_can_signal_cores(self):
+    @jax.jit
+    def f(x):
+      x_ref = jax.new_ref(x)
+      y_ref = jax.new_ref(jnp.empty_like(x))
+      @pl.core_map(pltpu.create_tensorcore_mesh("x"))
+      def _():
+        @functools.partial(pl.run_scoped, sem=pltpu.SemaphoreType.REGULAR)
+        def inner(sem):
+          s = jax.lax.axis_size("x")
+          for i in range(s):
+            pl.semaphore_signal(sem, device_id={"x": i})
+          pl.semaphore_wait(sem, s)
+        pltpu.sync_copy(x_ref, y_ref)
+      return jax.freeze(y_ref)
+    x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
+    np.testing.assert_array_equal(f(x), x)
+
   def test_can_query_core_index(self):
     mesh = pltpu.create_tensorcore_mesh("x")
     slc_size = 16 // mesh.shape["x"]
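The new test sets up an all-core barrier: each core signals the semaphore of every core once, then waits for one signal per core before copying. The plain-Python model below (core count assumed, purely illustrative and not Pallas code) shows why `pl.semaphore_wait(sem, s)` only proceeds once every core has signaled.

```python
num_cores = 4                     # assumed core count for illustration
semaphores = [0] * num_cores      # one counting semaphore per core

# Each core signals the semaphore of every core once
# (pl.semaphore_signal(sem, device_id={"x": i}) in the test).
for signaling_core in range(num_cores):
  for target_core in range(num_cores):
    semaphores[target_core] += 1

# pl.semaphore_wait(sem, num_cores) on a core succeeds only after it has
# received one signal from every core, so every counter equals num_cores.
assert all(count == num_cores for count in semaphores)
```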