Skip to content

Commit 21ec99b

Browse files
yuan-luoluoyuan.luo
andauthored
[VLM][Doc] Document for VLM DP Encoder (#14279)
Co-authored-by: luoyuan.luo <[email protected]>
1 parent 383689e commit 21ec99b

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# DP for Multi-Modal Encoder in SGLang
2+
3+
A typical VLM architecture involves two main components: an multi-modal encoder and a text decoder.
4+
5+
Most VLMs utilize a Vision Transformer (ViT) as their multi-modal encoder, it is responsible for processing visual data, extracting features (objects, colors, textures, etc.), and transforming them into a format that can be understood by the model.
6+
7+
The text deocoder is based on LLM. It processes textual data and generates output based on the encoded visual features.
8+
9+
However, since the size of ViT is very small compared to language decoders,
10+
there is relatively little gain from TP. On the other hand, TP incurs significant communication
11+
overhead because of all-reduce being performed after every layer.
12+
13+
Placing the ViT in data parallel while keeping the LLM in tensor parallel consistently lowers TTFT and boosts end-to-end throughput. In this hybrid layout, the vision front-end becomes parallel and lightweight, while scarce interconnect bandwidth and collective ops are reserved for the LLM.
14+
15+
Data parallelism replicates the entire model across multiple GPU sets and processes different batches of requests in parallel.
16+
17+
## Pros and Cons for DP Multi-Modal Encoder
18+
19+
- Unfavorable compute/communication ratio for small ViTs
20+
ViTs used in multimodal stacks are typically modest in parameter count and activation sizes. TP introduces per-layer all-reduce collectives (attention/MLP) whose latency and synchronization overhead outweigh the speedup of splitting relatively small GEMMs. With DP, each GPU runs a full ViT locally—no inference-time collectives—so latency is dominated by compute, not wire time.
21+
22+
- Graph-capture gaps amplify TP overhead
23+
In production, the vision path often has dynamic shapes (pre/post-processing, variable resolution, patching) that break CUDA Graphs and limit torch.compile fusion. Without capture, we need to pay extra kernel-launch and framework overhead; TP then multiplies that cost with additional NCCL synchronizations. Keeping ViT in DP avoids layering collective latency on top of non-captured kernels.
24+
25+
- Better interconnect hygiene for the true bottleneck (the LLM)
26+
The LLM’s prefill and decode phases benefit materially from TP on fast links. Offloading ViT to DP eliminates “chatty” small collectives on the same fabric, reducing congestion and jitter for the LLM’s large, bandwidth-hungry all-reduces.
27+
28+
- Shorter and steadier critical path → lower TTFT
29+
TTFT ≈ T(image encode via ViT) + T(LLM prefill) + T(softmax/sample)
30+
DP has several advantages:
31+
(a) batch and prefetch ViT encodes independently,
32+
(b) overlap them with other requests’ LLM decodes on separate streams,
33+
(c) hand off compact visual embeddings to the TP LLM with minimal queuing.
34+
35+
- For vision encoders that use hardware-unoptimized Conv3D operations,
36+
batch-level DP can provide another 40% improvement compared to regular TP.
37+
38+
- Nevertheless, since the weights of the multi-modal encoder are replicated across each TP rank,
39+
there will be a minor increase in memory consumption and may cause OOM if you can barely fit the model already.
40+
41+
## Command Example
42+
You can enable batch-level DP by setting `mm-enable-dp-encoder`, for example:
43+
```
44+
SGLANG_MM_FEATURE_CACHE_MB=4096 \
45+
SGLANG_USE_CUDA_IPC_TRANSPORT=1 \
46+
SGLANG_VLM_CACHE_SIZE_MB=512 \
47+
python3 -m sglang.launch_server --host 127.0.0.1 \
48+
--mem-fraction-static 0.7 \
49+
--port 30000 \
50+
--trust-remote-code \
51+
--dtype auto \
52+
--max-running-requests 4 \
53+
--chunked-prefill-size 8192 \
54+
--attention-backend flashinfer \
55+
--tp 4 \
56+
--enable-multimodal \
57+
--chat-template internvl-2-5 \
58+
--model OpenGVLab/InternVL2_5-8B \
59+
--disable-radix-cache \
60+
--mm-enable-dp-encoder
61+
```
62+
!!! important
63+
Batch-level multi-modal DP is not to be confused with API request-level DP
64+
(which is instead controlled by `data_parallel_size`).
65+
66+
## Known supported models
67+
- Qwen2.5-VL (<https://github.com/sgl-project/sglang/pull/13126>)
68+
- Qwen3-VL (<https://github.com/sgl-project/sglang/pull/13724>)
69+
- InternVL (<https://github.com/sgl-project/sglang/pull/13925>)

0 commit comments

Comments
 (0)