@@ -928,7 +928,7 @@ def _allreduce_impl(prim, pos_reducer, *args, axes, axis_index_groups):
928928 return [pos_reducer (arg , axes ) for arg in args ]
929929
930930def _allreduce_effectful_abstract_eval (* args , axes , axis_index_groups ):
931- _check_axis_names (axes )
931+ _check_axis_names (axes , 'psum' )
932932 named_axes = tuple (axis for axis in axes if not isinstance (axis , int ))
933933 pos_axes = tuple (axis for axis in axes if isinstance (axis , int ))
934934 if axis_index_groups is not None :
@@ -949,7 +949,7 @@ def _psum_invariant_abstract_eval(name, *args, axes, axis_index_groups):
949949 * args , axes = axes , axis_index_groups = axis_index_groups )
950950
951951 assert isinstance (axes , tuple )
952- _check_axis_names (axes )
952+ _check_axis_names (axes , 'psum' )
953953 arg_vma = [a .vma for a in args ]
954954 # If intersection between arg_vma and axes is empty, error
955955 if any (not set (axes ) & a for a in arg_vma ):
@@ -985,12 +985,14 @@ def _pmin_pmax_abstract_eval(name, *args, axes, axis_index_groups):
985985 return _psum_invariant_abstract_eval (
986986 name , * args , axes = axes , axis_index_groups = axis_index_groups )
987987
def _check_axis_names(axes, api_name):
  """Raise ``NameError`` if a named axis in ``axes`` is not currently bound.

  Integer entries of ``axes`` are ignored; every other entry is looked up in
  the axis environment returned by ``core.get_axis_env()``.

  Args:
    axes: iterable of axis identifiers. Only non-``int`` entries are checked.
    api_name: name of the calling API; interpolated into the error message so
      the user is told which call needs to move.

  Raises:
    NameError: on the first non-integer axis for which the environment's
      ``axis_exists`` returns a falsy value.
  """
  # Filter before touching the axis environment so the evaluation order is
  # the same as a "collect names, then validate" implementation.
  candidates = [entry for entry in axes if not isinstance(entry, int)]
  env = core.get_axis_env()
  for axis in candidates:
    if env.axis_exists(axis):
      continue
    raise NameError(
        f"Found an unbound axis name: {axis}. To fix this, please call"
        f" {api_name} under `jax.shard_map`.")
994996
995997def _allreduce_lowering (prim , pos_fn , ctx , * args , axes , axis_index_groups ):
996998 if axis_index_groups is not None and ("tpu" in ctx .module_context .platforms ):
@@ -1166,7 +1168,7 @@ def _ppermute_batcher(axis_data, vals_in, dims_in, axis_name, perm):
11661168 return v .take (perm_indices , d ), d
11671169
def _raise_to_shaped_abstract_eval(x, *, axis_name, **params):
  """Abstract-evaluation rule that validates ``axis_name`` and returns ``x``.

  Args:
    x: the abstract value being evaluated; returned unchanged.
    axis_name: axis identifiers forwarded to ``_check_axis_names`` and
      ``collective_vma_rule``.
    **params: accepted for signature compatibility and not read here.

  Returns:
    ``x`` itself.
  """
  # Both helpers are told the same collective name; name it once.
  collective = 'ppermute'
  _check_axis_names(axis_name, collective)
  # Called for its checks/side effects only; the return value is discarded.
  collective_vma_rule(collective, axis_name, x)
  return x
11721174
@@ -1218,7 +1220,7 @@ def _psend_lowering_gpu(ctx, x, *, axis_name, perm):
12181220
12191221
12201222def _psend_abstract_eval (x , * , axis_name , ** params ):
1221- _check_axis_names (axis_name )
1223+ _check_axis_names (axis_name , 'psend' )
12221224 return abstract_token , {
12231225 * map (core .NamedAxisEffect , axis_name ),
12241226 single_side_collective_effect ,
@@ -1492,7 +1494,7 @@ def _all_to_all_effectful_abstract_eval(
14921494 del tiled # expand_dims and squeeze is done in `all_to_all` if `True`
14931495 if not isinstance (axis_name , (list , tuple )):
14941496 axis_name = (axis_name ,)
1495- _check_axis_names (axis_name )
1497+ _check_axis_names (axis_name , 'all_to_all' )
14961498 shape = list (input_aval .shape )
14971499 axis_size = (
14981500 _axis_size (axis_name )
@@ -1581,7 +1583,7 @@ def _ragged_all_to_all_effectful_abstract_eval(
15811583 " size, but got shape {}" .format (recv_sizes .shape )
15821584 )
15831585
1584- _check_axis_names (axis_name )
1586+ _check_axis_names (axis_name , 'ragged_all_to_all' )
15851587 out_aval = output .update (shape = output .shape , weak_type = False )
15861588 effects = {* map (core .NamedAxisEffect , axis_name )}
15871589 return out_aval , effects
@@ -1802,7 +1804,7 @@ def _all_gather_effectful_abstract_eval(
18021804):
18031805 if not isinstance (axis_name , (list , tuple )):
18041806 axis_name = (axis_name ,)
1805- _check_axis_names (axis_name )
1807+ _check_axis_names (axis_name , 'all_gather' )
18061808 new_shape = list (x_aval .shape )
18071809 if tiled :
18081810 new_shape [all_gather_dimension ] *= axis_size
@@ -1920,7 +1922,7 @@ def bind(leaf):
19201922def _all_gather_invariant_effectful_abstract_eval (
19211923 x_aval , * , all_gather_dimension , axis_name , axis_size , tiled
19221924):
1923- _check_axis_names (axis_name )
1925+ _check_axis_names (axis_name , 'all_gather_invariant' )
19241926 new_shape = list (x_aval .shape )
19251927 if tiled :
19261928 new_shape [all_gather_dimension ] *= axis_size
@@ -2026,7 +2028,7 @@ def _reduce_scatter_effectful_abstract_eval(
20262028):
20272029 if not isinstance (axis_name , (list , tuple )):
20282030 axis_name = (axis_name ,)
2029- _check_axis_names (axis_name )
2031+ _check_axis_names (axis_name , 'reduce_scatter' )
20302032 new_shape = list (x_aval .shape )
20312033 scatter_dim_input_size = x_aval .shape [scatter_dimension ]
20322034 if tiled :
@@ -2244,7 +2246,7 @@ def _axis_index_lowering(ctx, *, axis_name):
22442246def _axis_index_effectful_abstract_eval (* , axis_name ):
22452247 effect = {core .NamedAxisEffect (axis_name )}
22462248 axis_name = (axis_name ,) if not isinstance (axis_name , tuple ) else axis_name
2247- _check_axis_names (axis_name )
2249+ _check_axis_names (axis_name , 'axis_index' )
22482250 mesh = get_abstract_mesh ()
22492251 sharding = NamedSharding (mesh , P ())
22502252 vma = ((frozenset (axis_name ) if mesh ._any_axis_manual else frozenset ())
@@ -2280,7 +2282,7 @@ def _pgather_impl(src, idx, *, axes):
22802282def _pgather_abstract_eval (src , idx , * , axes ):
22812283 # TODO: Avals with names rule: remove all axes from src, insert those from idx
22822284 # The order is important, because it is ok to re-insert one of the deleted axes!
2283- _check_axis_names (axes )
2285+ _check_axis_names (axes , 'pgather' )
22842286 shape = list (src .shape )
22852287 for axis in sorted ((a for a in axes if isinstance (a , int )), reverse = True ):
22862288 del shape [axis ]
0 commit comments