Skip to content
2 changes: 1 addition & 1 deletion src/transformers/integrations/tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig
The `is_weight` is important because for weights, we want to support `.weights` and `.bias` cases seamlessly! but
not parent classes for `post_init` calls
"""
generic_param_name = re.sub(r"\d+", "*", parameter_name)
generic_param_name = re.sub(r"\.\d+\.", ".*.", parameter_name)
if generic_param_name in tp_plan:
return tp_plan[generic_param_name]
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan and is_weight:
Expand Down