Commit c9e2090

fix: Support PP for Mistral Small 3.1 (#14254)
Parent: 106df4e

1 file changed: 3 additions, 1 deletion

python/sglang/srt/models/llava.py

Lines changed: 3 additions & 1 deletion
@@ -41,7 +41,7 @@
     MultimodalDataItem,
     MultimodalInputs,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
@@ -785,6 +785,7 @@ def forward(
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ):
         hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
@@ -796,6 +797,7 @@ def forward(
             },
             placeholder_tokens=None,  # using mm_item.pad_value
             positions=positions,
+            pp_proxy_tensors=pp_proxy_tensors,
         )

         return hidden_states
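
The gist of the fix: the Llava wrapper's forward now accepts an optional pp_proxy_tensors argument and passes it on to general_mm_embed_routine, so that under pipeline parallelism (PP) a non-first stage can resume from the hidden states handed over by the previous stage rather than recomputing embeddings. Below is a minimal, self-contained sketch of that pattern; ProxyTensors and ToyPipelineStage are hypothetical names used only for illustration, not sglang's actual classes (the real class is PPProxyTensors, as shown in the diff above).

from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
import torch.nn as nn


@dataclass
class ProxyTensors:
    # Hypothetical stand-in for PPProxyTensors: tensors received from the
    # previous pipeline-parallel stage (e.g. its output hidden states).
    tensors: Dict[str, torch.Tensor] = field(default_factory=dict)


class ToyPipelineStage(nn.Module):
    # Minimal model stage showing how an optional pp_proxy_tensors argument
    # is threaded through forward(), mirroring the shape of the change above.

    def __init__(self, vocab_size: int = 100, hidden_size: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.layer = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        pp_proxy_tensors: Optional[ProxyTensors] = None,
    ) -> torch.Tensor:
        if pp_proxy_tensors is not None:
            # Intermediate/last stage: resume from the previous stage's
            # hidden states instead of re-embedding the tokens.
            hidden_states = pp_proxy_tensors.tensors["hidden_states"]
        else:
            # First stage: compute token embeddings locally.
            hidden_states = self.embed(input_ids)
        return self.layer(hidden_states)


if __name__ == "__main__":
    stage0, stage1 = ToyPipelineStage(), ToyPipelineStage()
    ids = torch.tensor([1, 2, 3])
    # Stage 0 embeds the tokens; stage 1 consumes stage 0's output via the
    # proxy-tensor argument, never touching the embeddings itself.
    out0 = stage0(ids)
    out1 = stage1(ids, ProxyTensors({"hidden_states": out0}))
    print(out0.shape, out1.shape)  # torch.Size([3, 16]) torch.Size([3, 16])

Accepting and forwarding this argument is what appears to enable the PP code path for the Llava-based Mistral Small 3.1 model referenced in the commit title.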
