From 97440f3ded514689bbe67ceea210db1e11f6d087 Mon Sep 17 00:00:00 2001
From: danielquintas8
Date: Sun, 9 Nov 2025 19:29:37 +0000
Subject: [PATCH 1/2] Refactor routing logic in multiple models to use dim=-1
 for softmax in route_tokens_to_experts method and correct argument naming

---
 src/transformers/models/dbrx/modeling_dbrx.py            | 2 +-
 src/transformers/models/dbrx/modular_dbrx.py             | 2 +-
 .../models/ernie4_5_moe/modeling_ernie4_5_moe.py         | 2 +-
 .../models/ernie4_5_moe/modular_ernie4_5_moe.py          | 2 +-
 .../models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py     | 7 ++++---
 .../models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py      | 7 ++++---
 src/transformers/models/jamba/modeling_jamba.py          | 6 +++---
 src/transformers/models/jamba/modular_jamba.py           | 6 +++---
 src/transformers/models/olmoe/modeling_olmoe.py          | 6 +++---
 src/transformers/models/olmoe/modular_olmoe.py           | 6 +++---
 src/transformers/models/qwen2_moe/modeling_qwen2_moe.py  | 6 +++---
 src/transformers/models/qwen2_moe/modular_qwen2_moe.py   | 6 +++---
 src/transformers/models/qwen3_moe/modeling_qwen3_moe.py  | 6 +++---
 src/transformers/models/qwen3_moe/modular_qwen3_moe.py   | 6 +++---
 14 files changed, 36 insertions(+), 34 deletions(-)

diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index a3f995d35b95..6d92c40a96a4 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -357,7 +357,7 @@ def __init__(self, config, **kwargs):
         self.top_k = config.ffn_config.moe_top_k
 
     def route_tokens_to_experts(self, router_logits):
-        router_logits = torch.nn.functional.softmax(router_logits, dim=1, dtype=router_logits.dtype)
+        router_logits = torch.nn.functional.softmax(router_logits, dim=-1, dtype=router_logits.dtype)
         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
         if self.moe_normalize_expert_weights is not None:
             router_top_value = router_top_value / torch.norm(
diff --git a/src/transformers/models/dbrx/modular_dbrx.py b/src/transformers/models/dbrx/modular_dbrx.py
index 46507e44d52d..754891c17cda 100644
--- a/src/transformers/models/dbrx/modular_dbrx.py
+++ b/src/transformers/models/dbrx/modular_dbrx.py
@@ -227,7 +227,7 @@ def __init__(self, config, **kwargs):
         self.top_k = config.ffn_config.moe_top_k
 
     def route_tokens_to_experts(self, router_logits):
-        router_logits = torch.nn.functional.softmax(router_logits, dim=1, dtype=router_logits.dtype)
+        router_logits = torch.nn.functional.softmax(router_logits, dim=-1, dtype=router_logits.dtype)
         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
         if self.moe_normalize_expert_weights is not None:
             router_top_value = router_top_value / torch.norm(
diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py
index c2dbd8d436d8..b669b83479e6 100644
--- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py
+++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py
@@ -362,7 +362,7 @@ def route_tokens_to_experts(self, hidden_states):
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             router_logits = self.gate(hidden_states.float())
 
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
         _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
         routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
         routing_weights = routing_weights / torch.clamp(
diff --git a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py
index b12958b785b7..525c348a2847 100644
--- a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py
+++ b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py
@@ -143,7 +143,7 @@ def route_tokens_to_experts(self, hidden_states):
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             router_logits = self.gate(hidden_states.float())
 
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
         _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
         routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
         routing_weights = routing_weights / torch.clamp(
diff --git a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
index 732bafbd336d..ce80b73a0ca3 100644
--- a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
+++ b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
@@ -289,11 +289,12 @@ def __init__(self, config: HunYuanMoEV1Config, layer_idx: Optional[int] = None):
         self.experts = HunYuanMoEV1Experts(config)
         self.shared_mlp = HunYuanMoEV1MLP(config)
 
-    def route_tokens_to_experts(self, hidden_states):
-        routing_weights = F.softmax(hidden_states, dim=1, dtype=torch.float)
+    def route_tokens_to_experts(self, router_logits):
+        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-        return selected_experts, routing_weights.to(hidden_states.dtype)
+        routing_weights = routing_weights.to(router_logits.dtype)
+        return selected_experts, routing_weights
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
diff --git a/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py
index 06269fedf784..da0c8165a21a 100644
--- a/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py
+++ b/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py
@@ -145,11 +145,12 @@ def __init__(self, config: HunYuanMoEV1Config, layer_idx: Optional[int] = None):
         self.experts = HunYuanMoEV1Experts(config)
         self.shared_mlp = HunYuanMoEV1MLP(config)
 
-    def route_tokens_to_experts(self, hidden_states):
-        routing_weights = F.softmax(hidden_states, dim=1, dtype=torch.float)
+    def route_tokens_to_experts(self, router_logits):
+        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-        return selected_experts, routing_weights.to(hidden_states.dtype)
+        routing_weights = routing_weights.to(router_logits.dtype)
+        return selected_experts, routing_weights
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index 94d8cdc3f7be..f6b425a9ef25 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -614,16 +614,16 @@ def __init__(self, config: JambaConfig):
         self.router = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
         self.experts = JambaExperts(config)
 
-    def route_tokens_to_experts(self, hidden_states, router_logits):
+    def route_tokens_to_experts(self, router_logits):
         routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
         top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1)
-        return top_k_index, top_k_weights.to(hidden_states.dtype)
+        return top_k_index, top_k_weights.to(router_logits.dtype)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         router_logits = self.router(hidden_states)
-        top_k_index, top_k_weights = self.route_tokens_to_experts(hidden_states, router_logits)
+        top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
         hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)
         hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         return hidden_states
diff --git a/src/transformers/models/jamba/modular_jamba.py b/src/transformers/models/jamba/modular_jamba.py
index c6cfe339fabb..8e638c531f33 100644
--- a/src/transformers/models/jamba/modular_jamba.py
+++ b/src/transformers/models/jamba/modular_jamba.py
@@ -501,16 +501,16 @@ def __init__(self, config: JambaConfig):
         self.router = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
         self.experts = JambaExperts(config)
 
-    def route_tokens_to_experts(self, hidden_states, router_logits):
+    def route_tokens_to_experts(self, router_logits):
         routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
         top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1)
-        return top_k_index, top_k_weights.to(hidden_states.dtype)
+        return top_k_index, top_k_weights.to(router_logits.dtype)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         router_logits = self.router(hidden_states)
-        top_k_index, top_k_weights = self.route_tokens_to_experts(hidden_states, router_logits)
+        top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
         hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)
         hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         return hidden_states
diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py
index f6034bd9fc6f..97b8005236d7 100644
--- a/src/transformers/models/olmoe/modeling_olmoe.py
+++ b/src/transformers/models/olmoe/modeling_olmoe.py
@@ -339,19 +339,19 @@ def __init__(self, config):
         self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
         self.experts = OlmoeExperts(config)
 
-    def route_tokens_to_experts(self, hidden_states, router_logits):
+    def route_tokens_to_experts(self, router_logits):
         routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1)
         top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1)
         if self.norm_topk_prob:
             top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
-        top_k_weights = top_k_weights.to(hidden_states.dtype)
+        top_k_weights = top_k_weights.to(router_logits.dtype)
         return top_k_index, top_k_weights
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         router_logits = self.gate(hidden_states)
-        top_k_index, top_k_weights = self.route_tokens_to_experts(hidden_states, router_logits)
+        top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
         final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape(
             batch_size, sequence_length, hidden_dim
         )
diff --git a/src/transformers/models/olmoe/modular_olmoe.py b/src/transformers/models/olmoe/modular_olmoe.py
index 8220a0d7a0f0..bba04494c1ef 100644
--- a/src/transformers/models/olmoe/modular_olmoe.py
+++ b/src/transformers/models/olmoe/modular_olmoe.py
@@ -134,19 +134,19 @@ def __init__(self, config):
         self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
         self.experts = OlmoeExperts(config)
 
-    def route_tokens_to_experts(self, hidden_states, router_logits):
+    def route_tokens_to_experts(self, router_logits):
         routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1)
         top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1)
         if self.norm_topk_prob:
             top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
-        top_k_weights = top_k_weights.to(hidden_states.dtype)
+        top_k_weights = top_k_weights.to(router_logits.dtype)
         return top_k_index, top_k_weights
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         router_logits = self.gate(hidden_states)
-        top_k_index, top_k_weights = self.route_tokens_to_experts(hidden_states, router_logits)
+        top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
         final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape(
             batch_size, sequence_length, hidden_dim
         )
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index d1e309f612c6..b18238a47ef5 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -335,7 +335,7 @@ def __init__(self, config):
         self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size)
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
-    def route_tokens_to_experts(self, hidden_states, router_logits):
+    def route_tokens_to_experts(self, router_logits):
         routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)
         if self.norm_topk_prob:
@@ -343,12 +343,12 @@ def route_tokens_to_experts(self, hidden_states, router_logits):
         routing_weights = routing_weights.to(router_logits.dtype)
         return selected_experts, routing_weights
 
-    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
         shared_expert_output = self.shared_expert(hidden_states_reshaped)
         router_logits = self.gate(hidden_states_reshaped)
-        selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits)
+        selected_experts, routing_weights = self.route_tokens_to_experts(router_logits)
         expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
 
         shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output
diff --git a/src/transformers/models/qwen2_moe/modular_qwen2_moe.py b/src/transformers/models/qwen2_moe/modular_qwen2_moe.py
index 56c100f94b93..9e22ceb8ba70 100644
--- a/src/transformers/models/qwen2_moe/modular_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modular_qwen2_moe.py
@@ -102,7 +102,7 @@ def __init__(self, config):
         self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size)
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
-    def route_tokens_to_experts(self, hidden_states, router_logits):
+    def route_tokens_to_experts(self, router_logits):
         routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)
         if self.norm_topk_prob:
@@ -110,12 +110,12 @@ def route_tokens_to_experts(self, hidden_states, router_logits):
         routing_weights = routing_weights.to(router_logits.dtype)
         return selected_experts, routing_weights
 
-    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
         shared_expert_output = self.shared_expert(hidden_states_reshaped)
         router_logits = self.gate(hidden_states_reshaped)
-        selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits)
+        selected_experts, routing_weights = self.route_tokens_to_experts(router_logits)
         expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
 
         shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output
diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
index ff0855c223ee..92f3caded786 100644
--- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
+++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -251,7 +251,7 @@ def __init__(self, config: Qwen3MoeConfig):
         self.num_experts_per_tok = config.num_experts_per_tok
         self.norm_topk_prob = config.norm_topk_prob
 
-    def route_tokens_to_experts(self, hidden_states, router_logits):
+    def route_tokens_to_experts(self, router_logits):
         routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)
         if self.norm_topk_prob:
@@ -259,11 +259,11 @@ def route_tokens_to_experts(self, hidden_states, router_logits):
         routing_weights = routing_weights.to(router_logits.dtype)
         return selected_experts, routing_weights
 
-    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
         router_logits = self.gate(hidden_states_reshaped)
-        selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits)
+        selected_experts, routing_weights = self.route_tokens_to_experts(router_logits)
         final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
         return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
 
diff --git a/src/transformers/models/qwen3_moe/modular_qwen3_moe.py b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py
index 87a4bbfa9625..a4a0e53e04fb 100644
--- a/src/transformers/models/qwen3_moe/modular_qwen3_moe.py
+++ b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py
@@ -73,7 +73,7 @@ def __init__(self, config: Qwen3MoeConfig):
         self.num_experts_per_tok = config.num_experts_per_tok
         self.norm_topk_prob = config.norm_topk_prob
 
-    def route_tokens_to_experts(self, hidden_states, router_logits):
+    def route_tokens_to_experts(self, router_logits):
         routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)
         if self.norm_topk_prob:
@@ -81,11 +81,11 @@ def route_tokens_to_experts(self, hidden_states, router_logits):
         routing_weights = routing_weights.to(router_logits.dtype)
         return selected_experts, routing_weights
 
-    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
         router_logits = self.gate(hidden_states_reshaped)
-        selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits)
+        selected_experts, routing_weights = self.route_tokens_to_experts(router_logits)
         final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
         return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
 

From 229ab6ce34c5a28b264fef446d25c18a0cd1a04e Mon Sep 17 00:00:00 2001
From: danielquintas8
Date: Sun, 9 Nov 2025 20:11:22 +0000
Subject: [PATCH 2/2] modeling files

---
 .../models/flex_olmo/modeling_flex_olmo.py            |  6 +++---
 .../models/qwen3_next/modeling_qwen3_next.py           |  6 +++---
 .../models/qwen3_omni_moe/modeling_qwen3_omni_moe.py   | 12 ++++++------
 3 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/src/transformers/models/flex_olmo/modeling_flex_olmo.py b/src/transformers/models/flex_olmo/modeling_flex_olmo.py
index 01d10317cf09..032239d3ba3e 100644
--- a/src/transformers/models/flex_olmo/modeling_flex_olmo.py
+++ b/src/transformers/models/flex_olmo/modeling_flex_olmo.py
@@ -336,19 +336,19 @@ def __init__(self, config):
         self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
         self.experts = FlexOlmoExperts(config)
 
-    def route_tokens_to_experts(self, hidden_states, router_logits):
+    def route_tokens_to_experts(self, router_logits):
         routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1)
         top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1)
         if self.norm_topk_prob:
             top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
-        top_k_weights = top_k_weights.to(hidden_states.dtype)
+        top_k_weights = top_k_weights.to(router_logits.dtype)
         return top_k_index, top_k_weights
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         router_logits = self.gate(hidden_states)
-        top_k_index, top_k_weights = self.route_tokens_to_experts(hidden_states, router_logits)
+        top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
         final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape(
             batch_size, sequence_length, hidden_dim
         )
diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py
index 3847c43117a3..03c27d8c4d05 100644
--- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py
+++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py
@@ -865,7 +865,7 @@ def __init__(self, config):
         self.shared_expert = Qwen3NextMLP(config, intermediate_size=config.shared_expert_intermediate_size)
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
-    def route_tokens_to_experts(self, hidden_states, router_logits):
+    def route_tokens_to_experts(self, router_logits):
         routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)
         if self.norm_topk_prob:
@@ -873,12 +873,12 @@ def route_tokens_to_experts(self, hidden_states, router_logits):
         routing_weights = routing_weights.to(router_logits.dtype)
         return selected_experts, routing_weights
 
-    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
         shared_expert_output = self.shared_expert(hidden_states_reshaped)
         router_logits = self.gate(hidden_states_reshaped)
-        selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits)
+        selected_experts, routing_weights = self.route_tokens_to_experts(router_logits)
         expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
 
         shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output
diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
index aabd906dc3b2..a4145a4be5ef 100644
--- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
+++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
@@ -1365,7 +1365,7 @@ def __init__(self, config: Qwen3OmniMoeThinkerConfig):
         self.num_experts_per_tok = config.num_experts_per_tok
         self.norm_topk_prob = config.norm_topk_prob
 
-    def route_tokens_to_experts(self, hidden_states, router_logits):
+    def route_tokens_to_experts(self, router_logits):
         routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)
         if self.norm_topk_prob:
@@ -1373,11 +1373,11 @@ def route_tokens_to_experts(self, hidden_states, router_logits):
         routing_weights = routing_weights.to(router_logits.dtype)
         return selected_experts, routing_weights
 
-    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
         router_logits = self.gate(hidden_states_reshaped)
-        selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits)
+        selected_experts, routing_weights = self.route_tokens_to_experts(router_logits)
         final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
         return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
 
@@ -2755,7 +2755,7 @@ def __init__(self, config):
         )
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
-    def route_tokens_to_experts(self, hidden_states, router_logits):
+    def route_tokens_to_experts(self, router_logits):
         routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)
         if self.norm_topk_prob:
@@ -2763,12 +2763,12 @@ def route_tokens_to_experts(self, hidden_states, router_logits):
         routing_weights = routing_weights.to(router_logits.dtype)
         return selected_experts, routing_weights
 
-    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
         shared_expert_output = self.shared_expert(hidden_states_reshaped)
         router_logits = self.gate(hidden_states_reshaped)
-        selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits)
+        selected_experts, routing_weights = self.route_tokens_to_experts(router_logits)
         expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
 
         shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output