@@ -140,6 +140,16 @@ def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int
140140 return [single_size ] * blocks
141141
142142
143+ def replace_layer_number_by_wildcard (name : str ) -> str :
144+ """
145+ Replace the numbers in the `name` by wildcards, only if they are in-between dots (`.`) or if they are between
146+ a dot (`.`) and the end of the string.
147+ This matches how modules are named/numbered when using a nn.ModuleList or nn.Sequential, but will NOT match
148+ numbers in a parameter name itself, e.g. if the param is named `"w1"` or `"w2"`.
149+ """
150+ return re .sub (r"\.\d+(\.|$)" , lambda m : ".*" + m .group (1 ), name )
151+
152+
143153def _get_parameter_tp_plan (parameter_name : str , tp_plan : dict [str , str ], is_weight = True ) -> str | None :
144154 """
145155 Get the TP style for a parameter from the TP plan.
@@ -150,11 +160,11 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig
150160 The `is_weight` is important because for weights, we want to support `.weights` and `.bias` cases seamlessly! but
151161 not parent classes for `post_init` calls
152162 """
153- generic_param_name = re . sub ( r"\d+" , "*" , parameter_name )
163+ generic_param_name = replace_layer_number_by_wildcard ( parameter_name )
154164 if generic_param_name in tp_plan :
155165 return tp_plan [generic_param_name ]
156- elif "." in generic_param_name and generic_param_name .rsplit ("." , 1 )[0 ] in tp_plan and is_weight :
157- return tp_plan [generic_param_name . rsplit ( "." , 1 )[ 0 ] ]
166+ elif is_weight and "." in generic_param_name and ( module_name := generic_param_name .rsplit ("." , 1 )[0 ]) in tp_plan :
167+ return tp_plan [module_name ]
158168 return None
159169
160170
@@ -1086,7 +1096,7 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
10861096 if tp_plan is None :
10871097 return
10881098
1089- generic_keys = {re . sub ( r"\d+" , "*" , key ) for key in expected_keys }
1099+ generic_keys = {replace_layer_number_by_wildcard ( key ) for key in expected_keys }
10901100 unsharded_layers = set (generic_keys )
10911101 unused_rules = tp_plan
10921102
0 commit comments