@@ -1351,20 +1351,26 @@ def _semaphore_wait_discharge_rule(in_avals,
13511351)
13521352
13531353
1354- def _device_id_dict_to_mesh (mesh_context : pallas_utils .MeshInfo , device_id_dict , get_axis_index ):
1354+ def _device_id_dict_to_mesh (mesh_context : pallas_utils .MeshInfo | None , device_id_dict , get_axis_index ):
13551355 i32 = ir .IntegerType .get_signless (32 )
1356- assert mesh_context is not None
1357- mesh_axis_sizes = dict (zip (mesh_context .axis_names , mesh_context .mesh_shape ))
1356+ if mesh_context is None :
1357+ mesh_axis_sizes = {}
1358+ else :
1359+ mesh_axis_sizes = dict (
1360+ zip (mesh_context .axis_names , mesh_context .mesh_shape )
1361+ )
13581362 physical_axis_dict = {}
13591363 # Handle joint axes (i.e., one logical axis over >1 physical axes)
1360- for axis , idx in device_id_dict .items ():
1361- if isinstance (axis , tuple ) and any (a in mesh_context .axis_names for a in axis ):
1362- if not all (a in mesh_context .axis_names for a in axis ):
1364+ for axis_name , idx in device_id_dict .items ():
1365+ if isinstance (axis_name , tuple ) and any (
1366+ a in mesh_axis_sizes for a in axis_name
1367+ ):
1368+ if not all (a in mesh_axis_sizes for a in axis_name ):
13631369 raise NotImplementedError (
1364- f"{ axis } mixes JAX mesh and Pallas mesh grid axes"
1370+ f"{ axis_name } mixes JAX mesh and Pallas mesh grid axes"
13651371 )
1366- axes_dimensions = [mesh_axis_sizes [name ] for name in axis ]
1367- for axis_index , axis_name in enumerate (axis ):
1372+ axes_dimensions = [mesh_axis_sizes [name ] for name in axis_name ]
1373+ for axis_index , axis_name in enumerate (axis_name ):
13681374 axis_size = mesh_axis_sizes [axis_name ]
13691375 inner_mesh_size = math .prod (axes_dimensions [axis_index + 1 :])
13701376 minor_divisor = arith .constant (i32 , inner_mesh_size )
@@ -1387,17 +1393,17 @@ def _device_id_dict_to_mesh(mesh_context: pallas_utils.MeshInfo, device_id_dict,
13871393 )
13881394 physical_axis_dict [axis_name ] = device_idx
13891395 else :
1390- physical_axis_dict [axis ] = idx
1396+ physical_axis_dict [axis_name ] = idx
13911397 device_id = []
1392- for axis in mesh_context . axis_names :
1393- if axis in physical_axis_dict :
1394- device_id .append (physical_axis_dict [axis ])
1398+ for axis_name in mesh_axis_sizes :
1399+ if axis_name in physical_axis_dict :
1400+ device_id .append (physical_axis_dict [axis_name ])
13951401 else :
1396- device_id .append (get_axis_index (axis ))
1402+ device_id .append (get_axis_index (axis_name ))
13971403 non_mesh_axes = {
13981404 k : v
13991405 for k , v in physical_axis_dict .items ()
1400- if k not in mesh_context . axis_names
1406+ if k not in mesh_axis_sizes
14011407 }
14021408 return tuple (device_id ), non_mesh_axes
14031409
@@ -1419,13 +1425,15 @@ def device_id_to_logical(
14191425 "`device_id_type` must be MESH if `device_id` is a dict,"
14201426 f" got: { device_id_type = } ."
14211427 )
1422- assert mesh_context is not None
14231428 device_id , non_mesh_axes = _device_id_dict_to_mesh (mesh_context , device_id , get_axis_index )
14241429 if device_id_type is DeviceIdType .MESH :
1425- assert mesh_context is not None
14261430 # Mesh means we are passed the mesh coordinates for the device
14271431 device_ids = tree_util .tree_leaves (device_id )
1428- mesh_strides = mesh_context .mesh_strides
1432+ mesh_strides : tuple [int , ...]
1433+ if mesh_context is None :
1434+ mesh_strides = ()
1435+ else :
1436+ mesh_strides = mesh_context .mesh_strides
14291437 if len (device_ids ) != len (mesh_strides ):
14301438 raise ValueError (
14311439 "Number of device ids must match the number of mesh axes, but got"
0 commit comments