
Commit 6ae8ca6

Merge branch 'main' into achatter/fp8_scaled_mm
2 parents: 2b7d4a1 + 1bb6c78

File tree: 13 files changed, +570 −76 lines

.github/workflows/pr-test-xpu.yml
Lines changed: 4 additions & 4 deletions

@@ -27,15 +27,15 @@ jobs:
           docker build \
             --build-arg SG_LANG_KERNEL_BRANCH=${{ github.head_ref }} \
             --build-arg SG_LANG_KERNEL_REPO=${{ github.event.pull_request.head.repo.clone_url }} \
-            --no-cache --progress=plain -f Dockerfile.xpu_kernel -t xpu_sglang:pvc .
+            --no-cache --progress=plain -f Dockerfile.xpu_kernel -t xpu_sglang:kernel .

       - name: Run container
         run: |
           docker run -dt \
             --device /dev/dri/ \
             --name ci_sglang_xpu \
             -e HF_TOKEN=$(cat ~/huggingface_token.txt) \
-            xpu_sglang:pvc
+            xpu_sglang:kernel

       - name: Install Dependency
         timeout-minutes: 20
@@ -49,13 +49,13 @@ jobs:
         timeout-minutes: 20
         run: |
           docker exec -w /root/sglang ci_sglang_xpu \
-            /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 -m pytest -v -s test_awq_dequant.py test_topk_softmax.py test_flash_attention.py"
+            /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 run_suite.py --suite per-commit"

       - name: Run Sglang Kernel Benchmarks
         timeout-minutes: 20
         run: |
           docker exec -w /root/sglang ci_sglang_xpu \
-            /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/benchmark && python3 bench_flash_attn.py"
+            /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/benchmark && python3 bench_flash_attn.py && python3 bench_moe_topk_softmax.py"

       - name: Run E2E Bfloat16 tests
         timeout-minutes: 20
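The per-commit test step now calls run_suite.py --suite per-commit instead of listing pytest files inline, so the workflow no longer has to change when tests are added to the suite. Below is a minimal sketch of what such a suite driver might look like, assuming it simply maps suite names to pytest file lists; the suite contents, option name, and structure are assumptions for illustration only, not the actual sgl-kernel-xpu/tests/run_suite.py:

```python
# Hypothetical sketch of a suite driver (not the actual run_suite.py).
import argparse
import subprocess
import sys

# Assumed mapping: suite name -> pytest files run for that suite.
SUITES = {
    "per-commit": [
        "test_awq_dequant.py",
        "test_topk_softmax.py",
        "test_flash_attention.py",
    ],
}


def main() -> int:
    parser = argparse.ArgumentParser()
    parser.add_argument("--suite", choices=sorted(SUITES), required=True)
    args = parser.parse_args()
    # Equivalent to running the listed files with pytest, as the old
    # workflow step did inline.
    cmd = [sys.executable, "-m", "pytest", "-v", "-s", *SUITES[args.suite]]
    return subprocess.run(cmd).returncode


if __name__ == "__main__":
    sys.exit(main())
```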

CMakeLists.txt
Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@ set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable headers only mode in cutla
 FetchContent_Declare(
   repo-cutlass-sycl
   GIT_REPOSITORY https://github.com/intel/sycl-tla.git
-  GIT_TAG d2292f0071125c32f92e8963f8dfba8ec3e491f7
+  GIT_TAG 5a0b7a8b7024175f223f4a47535650f317bcbbf3
   GIT_SHALLOW OFF
 )

Dockerfile.xpu_kernel
Lines changed: 1 addition & 1 deletion

@@ -55,7 +55,7 @@ RUN --mount=type=secret,id=github_token \
     cd sgl-kernel-xpu && \
     pip install -v . &&\
     # Install required packages for sglang workloads
-    pip install msgspec blake3 py-cpuinfo compressed_tensors gguf partial_json_parser einops matplotlib pandas --root-user-action=ignore && \
+    pip install msgspec blake3 py-cpuinfo compressed_tensors gguf partial_json_parser einops matplotlib pandas --root-user-action=ignore aiohttp && \
     conda install libsqlite=3.48.0 -y && \
     echo ". /miniforge3/bin/activate; conda activate py${PYTHON_VERSION}; cd /root/" >> /root/.bashrc;

benchmark/bench_flash_attn.py
Lines changed: 8 additions & 10 deletions

@@ -12,7 +12,7 @@ def flash_attn_baseline(
     causal,
     window_size,
     softmax_scale,
-    softmax_sink,
+    sinks,
     cache_seqlens,
     page_table,
     cu_seqlens_q,
@@ -24,7 +24,7 @@ def flash_attn_baseline(
         k_cache,
         v_cache,
         causal=causal,
-        softmax_sink=softmax_sink,
+        sinks=sinks,
         window_size=window_size,
         softmax_scale=softmax_scale,
         page_table=page_table,
@@ -39,7 +39,7 @@ def flash_attn_baseline(
 # Benchmark configurations
 causal = [True, False]
 local = [True, False]
-use_softmax_sink = [True, False]
+use_sinks = [True, False]
 batch_size = [1, 16]
 q_seq_length_range = [1, 512, 1024]
 kv_seq_length_range = [512, 1024, 2048, 4096, 8192, 16384]
@@ -50,7 +50,7 @@ def flash_attn_baseline(
     product(
         causal,
         local,
-        use_softmax_sink,
+        use_sinks,
         batch_size,
         q_seq_length_range,
         kv_seq_length_range,
@@ -65,7 +65,7 @@ def flash_attn_baseline(
         x_names=[
             "causal",
             "local",
-            "use_softmax_sink",
+            "use_sinks",
             "batch_size",
             "q_seq_length",
             "kv_seq_length",
@@ -84,7 +84,7 @@ def flash_attn_baseline(
 def benchmark(
     causal,
     local,
-    use_softmax_sink,
+    use_sinks,
     batch_size,
     q_seq_length,
     kv_seq_length,
@@ -127,9 +127,7 @@ def benchmark(
     max_seqlen_q = q_seq_length
     window_size = (-1, -1) if not local else torch.randint(0, kv_seq_length, (2,))

-    softmax_sink = (
-        torch.randn(num_heads, device=device, dtype=dtype) if use_softmax_sink else None
-    )
+    sinks = torch.randn(num_heads, device=device, dtype=dtype) if use_sinks else None

     softmax_scale = 1.0 / (head_dim**0.5)

@@ -144,7 +142,7 @@ def benchmark(
         causal=causal,
         window_size=window_size,
         softmax_scale=softmax_scale,
-        softmax_sink=softmax_sink,
+        sinks=sinks,
         cache_seqlens=cache_seqlens,
         page_table=page_table,
         cu_seqlens_q=cu_seqlens_q,
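This file only tracks the kernel API rename from softmax_sink to sinks, a per-head tensor of sink logits. As background, here is a minimal PyTorch sketch of how per-head attention sinks are commonly folded into the softmax normalization; this is an assumption about the kernel's semantics based on the usual formulation, not code from this commit:

```python
import torch


def attention_with_sinks(q, k, v, sinks, softmax_scale):
    # q: [heads, q_len, d], k/v: [heads, kv_len, d], sinks: [heads]
    # Common formulation: the sink logit joins the softmax denominator
    # but contributes no value, draining attention mass from real tokens.
    scores = torch.einsum("hqd,hkd->hqk", q, k) * softmax_scale
    sink_logits = sinks.view(-1, 1, 1).expand(-1, scores.shape[1], 1)
    logits = torch.cat([scores, sink_logits], dim=-1)        # [heads, q_len, kv_len + 1]
    probs = torch.softmax(logits.float(), dim=-1)[..., :-1]  # drop the sink column
    return torch.einsum("hqk,hkd->hqd", probs, v.float()).to(q.dtype)
```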

benchmark/bench_moe_topk_softmax.py
Lines changed: 90 additions & 46 deletions

@@ -3,6 +3,7 @@
 import torch
 import triton
 from sgl_kernel import topk_softmax
+from utils import get_model_config, parse_args


 def vllm_topk_softmax(gating_output, topk):
@@ -23,7 +24,35 @@ def vllm_topk_softmax(gating_output, topk):
     return topk_weights, topk_indices


-def sglang_topk_softmax(gating_output, topk):
+def navtive_topk_softmax(
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    num_tokens, num_experts = gating_output.shape
+
+    import torch.nn.functional as F
+
+    topk_weights = torch.empty(
+        (num_tokens, topk), device=gating_output.device, dtype=torch.float32
+    )
+    topk_indices = torch.empty(
+        (num_tokens, topk), dtype=torch.int32, device=gating_output.device
+    )
+    topk_weights = F.softmax(gating_output.float(), dim=-1)
+    topk_weights, topk_indices = torch.topk(topk_weights, topk, dim=-1)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_indices
+
+
+def sglang_topk_softmax(
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
     num_tokens, num_experts = gating_output.shape

     topk_weights = torch.empty(
@@ -37,18 +66,18 @@ def sglang_topk_softmax(gating_output, topk):
     )

     topk_softmax(
-        topk_weights=topk_weights,
-        topk_ids=topk_indices,
-        token_expert_indices=token_expert_indices,
-        gating_output=gating_output,
+        topk_weights,
+        topk_indices,
+        gating_output,
+        renormalize=renormalize,
     )

     return topk_weights, topk_indices


 def calculate_diff(num_tokens, num_experts, topk):
     gating_output = torch.randn(
-        (num_tokens, num_experts), device="cuda", dtype=torch.float32
+        (num_tokens, num_experts), device=gating_output.device, dtype=torch.float32
     )
     weights_vllm, indices_vllm = vllm_topk_softmax(gating_output.clone(), topk)
     weights_sglang, indices_sglang = sglang_topk_softmax(gating_output.clone(), topk)
@@ -67,52 +96,67 @@ def calculate_diff(num_tokens, num_experts, topk):
     )


-num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
-num_experts_range = [32, 64, 128, 256, 12, 512]
-topk_range = [1, 2, 4, 8]
+def get_benchmark(device="xpu"):
+    @triton.testing.perf_report(
+        triton.testing.Benchmark(
+            x_names=["num_tokens", "num_experts", "topk", "dtype", "renormalize"],
+            x_vals=configs,
+            line_arg="provider",
+            line_vals=["sglang", "native"],
+            line_names=["SGLang", "native"],
+            styles=[("blue", "-"), ("green", "-")],
+            ylabel="Latency (us)",
+            plot_name="topk-softmax-performance",
+            args={},
+        )
+    )
+    def benchmark(num_tokens, num_experts, topk, dtype, renormalize, provider):

-configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
+        gating_output = torch.randn(
+            (num_tokens, num_experts), device=device, dtype=dtype
+        )

+        if provider == "sglang" or provider == "sglang1":
+            fn = lambda: sglang_topk_softmax(gating_output, topk, renormalize)
+        elif provider == "native":
+            fn = lambda: navtive_topk_softmax(gating_output, topk, renormalize)

-@triton.testing.perf_report(
-    triton.testing.Benchmark(
-        x_names=["num_tokens", "num_experts", "topk"],
-        x_vals=configs,
-        line_arg="provider",
-        line_vals=["sglang", "vllm"],
-        line_names=["SGLang", "VLLM"],
-        styles=[("blue", "-"), ("green", "-")],
-        ylabel="Latency (us)",
-        plot_name="topk-softmax-performance",
-        args={},
-    )
-)
-def benchmark(num_tokens, num_experts, topk, provider):
+        quantiles = [0.5, 0.2, 0.8]
+        ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

-    gating_output = torch.randn(
-        (num_tokens, num_experts), device="cuda", dtype=torch.float32
-    )
+        return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+    return benchmark

-    if provider == "vllm" or provider == "vllm1":
-        fn = lambda: vllm_topk_softmax(gating_output, topk)
-    elif provider == "sglang" or provider == "sglang1":
-        fn = lambda: sglang_topk_softmax(gating_output, topk)

-    quantiles = [0.5, 0.2, 0.8]
-    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
+if __name__ == "__main__":
+    # Run correctness test on small configs if not using a real model
+    args = parse_args()
+    params = get_model_config(args)
+
+    sweep_params = {
+        "num_tokens": args.num_tokens,
+        "num_experts": params["num_experts"] or [64],
+        "top_k": params["top_k"] or [2, 4],
+        "dtype": [torch.bfloat16],
+        "renormalize": [False],
+    }
+
+    keys = sweep_params.keys()
+    configs = list(itertools.product(*sweep_params.values()))
+    print(f"Testing {len(configs)} configurations...")
+    for config in configs:
+        num_tokens, num_experts, topk, dtype, renormalize = config
+        print(
+            f"Config: num_tokens={num_tokens}, num_experts={num_experts}, topk={topk}, dtype={dtype}, renormalize={renormalize}"
+        )

-    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+        # calculate_diff(num_tokens, num_experts, topk)

+    global benchmark_configs
+    benchmark_configs = configs

-if __name__ == "__main__":
-    configs = [
-        (20, 256, 4),
-        (20, 256, 8),
-        (20, 12, 4),
-        (20, 12, 1),
-        (20, 512, 4),
-        (20, 512, 1),
-    ]
-    for num_tokens, num_experts, topk in configs:
-        calculate_diff(num_tokens, num_experts, topk)
-    benchmark.run(print_data=True)
+    # Run benchmark
+    print("Starting performance benchmark...")
+    benchmark = get_benchmark()
+    benchmark.run(print_data=True, show_plots=False, save_path=".")
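The benchmark now compares the sgl_kernel topk_softmax against a native PyTorch reference (navtive_topk_softmax, named as in the commit) instead of the vLLM kernel, and sweeps dtype and renormalize from the model config. For reference, a small standalone sketch of the same math with illustrative sizes, showing the expected output shapes:

```python
import torch
import torch.nn.functional as F

# Illustrative sizes; the benchmark sweeps these from the model config.
num_tokens, num_experts, topk = 4, 8, 2
gating_output = torch.randn(num_tokens, num_experts)

# Same steps as the native reference in the benchmark:
# softmax over experts, top-k selection, optional renormalization.
weights = F.softmax(gating_output.float(), dim=-1)
topk_weights, topk_indices = torch.topk(weights, topk, dim=-1)
renormalize = True
if renormalize:
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

print(topk_weights.shape, topk_indices.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```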
