
Commit c4cfc2e

[TP] Fix parameter detection issue and some invalid TP-plans (#42129)
* fix
* add test
* fix test
* fix the obvious
* more fix
* fix
* continue to improve
* more fix
* more
* fix
* fix
* finally
* CI
1 parent 5c6d6be commit c4cfc2e

20 files changed, +124 -93 lines changed

src/transformers/integrations/tensor_parallel.py

Lines changed: 14 additions & 4 deletions
@@ -140,6 +140,16 @@ def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int]:
    return [single_size] * blocks


+def replace_layer_number_by_wildcard(name: str) -> str:
+    """
+    Replace the numbers in the `name` by wildcards, only if they are in-between dots (`.`) or if they are between
+    a dot (`.`) and the end of the string.
+    This matches how modules are named/numbered when using a nn.ModuleList or nn.Sequential, but will NOT match
+    numbers in a parameter name itself, e.g. if the param is named `"w1"` or `"w2"`.
+    """
+    return re.sub(r"\.\d+(\.|$)", lambda m: ".*" + m.group(1), name)
+
+
def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weight=True) -> str | None:
    """
    Get the TP style for a parameter from the TP plan.
@@ -150,11 +160,11 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weight=True) -> str | None:
    The `is_weight` is important because for weights, we want to support `.weights` and `.bias` cases seamlessly! but
    not parent classes for `post_init` calls
    """
-    generic_param_name = re.sub(r"\d+", "*", parameter_name)
+    generic_param_name = replace_layer_number_by_wildcard(parameter_name)
    if generic_param_name in tp_plan:
        return tp_plan[generic_param_name]
-    elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan and is_weight:
-        return tp_plan[generic_param_name.rsplit(".", 1)[0]]
+    elif is_weight and "." in generic_param_name and (module_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
+        return tp_plan[module_name]
    return None
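Note: as a quick illustration of the parameter-detection fix (a standalone sketch, not part of the diff; the parameter names are made up), the old `re.sub(r"\d+", "*", ...)` wildcarded every digit, so a parameter whose own name contains a digit (e.g. `w1`) could never match its TP-plan entry. The new helper only wildcards layer indices delimited by dots:

import re

def replace_layer_number_by_wildcard(name: str) -> str:
    # Same regex as in the diff above: only digits between two dots, or between a dot
    # and the end of the string, become "*".
    return re.sub(r"\.\d+(\.|$)", lambda m: ".*" + m.group(1), name)

# Layer indices coming from nn.ModuleList / nn.Sequential are still wildcarded...
assert replace_layer_number_by_wildcard("layers.11.self_attn.q_proj.weight") == "layers.*.self_attn.q_proj.weight"
# ...but digits inside a parameter name survive; the old `re.sub(r"\d+", "*", name)`
# would have produced "layers.*.mlp.w*" here and missed a plan entry "layers.*.mlp.w1".
assert replace_layer_number_by_wildcard("layers.0.mlp.w1") == "layers.*.mlp.w1"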

@@ -1086,7 +1096,7 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
    if tp_plan is None:
        return

-    generic_keys = {re.sub(r"\d+", "*", key) for key in expected_keys}
+    generic_keys = {replace_layer_number_by_wildcard(key) for key in expected_keys}
    unsharded_layers = set(generic_keys)
    unused_rules = tp_plan
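For context, a hedged sketch of how the reworked lookup resolves plan entries (the plans below are illustrative and assume `replace_layer_number_by_wildcard` and `_get_parameter_tp_plan` from the diff above are in scope):

# Illustrative plan, not taken from any real model.
tp_plan = {"layers.*.mlp.down_proj": "rowwise"}

# Weights and biases fall back to their parent module's rule:
_get_parameter_tp_plan("layers.3.mlp.down_proj.weight", tp_plan)  # -> "rowwise"
_get_parameter_tp_plan("layers.3.mlp.down_proj.bias", tp_plan)    # -> "rowwise"
# An exact match on the wildcarded name also works, e.g. for a parameter-level rule:
_get_parameter_tp_plan("layers.3.norm.weight", {"layers.*.norm.weight": "sequence_parallel"})  # -> "sequence_parallel"
# With is_weight=False (the docstring's post_init case), only exact matches count:
_get_parameter_tp_plan("layers.3.mlp.down_proj.weight", tp_plan, is_weight=False)  # -> None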

src/transformers/models/apertus/configuration_apertus.py

Lines changed: 0 additions & 1 deletion
@@ -106,7 +106,6 @@ class ApertusConfig(PreTrainedConfig):
        "layers.*.self_attn.o_proj": "rowwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
-        "layers.*.mlp.gate_proj": "colwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),

src/transformers/models/apertus/modular_apertus.py

Lines changed: 0 additions & 1 deletion
@@ -123,7 +123,6 @@ class ApertusConfig(LlamaConfig):
        "layers.*.self_attn.o_proj": "rowwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
-        "layers.*.mlp.gate_proj": "colwise",
    }

    def __init__(

src/transformers/models/aria/configuration_aria.py

Lines changed: 3 additions & 4 deletions
@@ -99,15 +99,14 @@ class AriaTextConfig(PreTrainedConfig):

    model_type = "aria_text"
    keys_to_ignore_at_inference = ["past_key_values"]
-    # Default tensor parallel plan for base model `AriaTextModel`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
-        "layers.*.mlp.gate_proj": "colwise",
-        "layers.*.mlp.up_proj": "colwise",
-        "layers.*.mlp.down_proj": "rowwise",
+        "layers.*.mlp.shared_experts.gate_proj": "colwise",
+        "layers.*.mlp.shared_experts.up_proj": "colwise",
+        "layers.*.mlp.shared_experts.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),

src/transformers/models/aria/modular_aria.py

Lines changed: 9 additions & 0 deletions
@@ -169,6 +169,15 @@ class AriaTextConfig(LlamaConfig):

    model_type = "aria_text"
    base_config_key = "text_config"
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.shared_experts.gate_proj": "colwise",
+        "layers.*.mlp.shared_experts.up_proj": "colwise",
+        "layers.*.mlp.shared_experts.down_proj": "rowwise",
+    }

    def __init__(
        self,
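The Aria entries above are part of the "invalid TP-plans" fix: in `AriaTextModel` the MLP projections live under `mlp.shared_experts`, so the old module-level keys (`"layers.*.mlp.gate_proj"` etc.) did not correspond to any real parameter path. A small illustrative check (reusing the helper from `tensor_parallel.py` above; the concrete parameter name is assumed):

# Hypothetical AriaText parameter name; wildcarding keeps the full module path, which
# matches the new key "layers.*.mlp.shared_experts.gate_proj" but not the old
# "layers.*.mlp.gate_proj".
replace_layer_number_by_wildcard("layers.7.mlp.shared_experts.gate_proj.weight")
# -> "layers.*.mlp.shared_experts.gate_proj.weight"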

src/transformers/models/doge/configuration_doge.py

Lines changed: 2 additions & 2 deletions
@@ -118,9 +118,9 @@ class DogeConfig(PreTrainedConfig):
        "layers.*.self_attn.dt_proj": "rowwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.input_layernorm.weight": "sequence_parallel",
-        "layers.*.input_residual.weight": "sequence_parallel",
+        "layers.*.input_residual": "sequence_parallel",
        "layers.*.post_attention_layernorm.weight": "sequence_parallel",
-        "layers.*.post_attention_residual.weight": "sequence_parallel",
+        "layers.*.post_attention_residual": "sequence_parallel",
        "norm.weight": "sequence_parallel",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",

src/transformers/models/doge/modular_doge.py

Lines changed: 2 additions & 2 deletions
@@ -146,9 +146,9 @@ class DogeConfig(PreTrainedConfig):
        "layers.*.self_attn.dt_proj": "rowwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.input_layernorm.weight": "sequence_parallel",
-        "layers.*.input_residual.weight": "sequence_parallel",
+        "layers.*.input_residual": "sequence_parallel",
        "layers.*.post_attention_layernorm.weight": "sequence_parallel",
-        "layers.*.post_attention_residual.weight": "sequence_parallel",
+        "layers.*.post_attention_residual": "sequence_parallel",
        "norm.weight": "sequence_parallel",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",

src/transformers/models/flex_olmo/configuration_flex_olmo.py

Lines changed: 3 additions & 3 deletions
@@ -114,9 +114,9 @@ class FlexOlmoConfig(PreTrainedConfig):
        "layers.*.self_attn.k_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.self_attn.v_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.self_attn.o_proj": "rowwise_rep",  # we need to replicate here due to the added norm on q and k
-        "layers.*.mlp.gate_proj": "colwise",
-        "layers.*.mlp.up_proj": "colwise",
-        "layers.*.mlp.down_proj": "rowwise",
+        "layers.*.mlp.experts.*.gate_proj": "colwise",
+        "layers.*.mlp.experts.*.up_proj": "colwise",
+        "layers.*.mlp.experts.*.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),

src/transformers/models/flex_olmo/modular_flex_olmo.py

Lines changed: 3 additions & 3 deletions
@@ -125,9 +125,9 @@ class FlexOlmoConfig(OlmoeConfig):
        "layers.*.self_attn.k_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.self_attn.v_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
        "layers.*.self_attn.o_proj": "rowwise_rep",  # we need to replicate here due to the added norm on q and k
-        "layers.*.mlp.gate_proj": "colwise",
-        "layers.*.mlp.up_proj": "colwise",
-        "layers.*.mlp.down_proj": "rowwise",
+        "layers.*.mlp.experts.*.gate_proj": "colwise",
+        "layers.*.mlp.experts.*.up_proj": "colwise",
+        "layers.*.mlp.experts.*.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),

src/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 1 addition & 1 deletion
@@ -630,7 +630,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config: Gemma3TextConfig
-    base_model_prefix = "language_model"
+    base_model_prefix = "model"

    def __init__(self, config: Gemma3TextConfig):
        super().__init__(config)
