
Commit 6d5e0a5

Merge branch 'main' into enable_autorun_on_ready_for_eval

2 parents: c5c9d34 + 155ad56


71 files changed: +1193 −434 lines

.buildkite/test-pipeline.yaml

Lines changed: 26 additions & 1 deletion

@@ -473,7 +473,9 @@ steps:
   - tests/compile
   commands:
   - pytest -v -s compile/test_full_graph.py
-  - pytest -v -s compile/test_fusions_e2e.py
+  # Limit to no custom ops to reduce running time
+  # Wrap with quotes to escape yaml and avoid starting -k string with a -
+  - "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and -quant_fp8'"

 - label: Cudagraph test
   timeout_in_minutes: 20
@@ -930,6 +932,29 @@ steps:
   - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
   # this runner has 2 GPUs available even though num_gpus=2 is not set
   - pytest -v -s tests/compile/test_fusion_all_reduce.py
+  # Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
+  # Wrap with quotes to escape yaml
+  - "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and Llama-3.1 and -quant_fp8 and -rms_norm'"
+
+- label: Blackwell Fusion E2E Tests # 30 min
+  timeout_in_minutes: 40
+  working_dir: "/vllm-workspace/"
+  gpu: b200
+  optional: true
+  num_gpus: 2
+  source_file_dependencies:
+  - csrc/quantization/fp4/
+  - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+  - vllm/v1/attention/backends/flashinfer.py
+  - vllm/compilation/
+  # can affect pattern matching
+  - vllm/model_executor/layers/layernorm.py
+  - vllm/model_executor/layers/activation.py
+  - vllm/model_executor/layers/quantization/input_quant_fp8.py
+  - tests/compile/test_fusions_e2e.py
+  commands:
+  - nvidia-smi
+  # Run all e2e fusion tests
   - pytest -v -s tests/compile/test_fusions_e2e.py

 - label: Blackwell GPT-OSS Eval
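The new `-k` expressions above select a subset of the parametrized fusion tests by substring-matching their test IDs; the surrounding quotes exist only so YAML does not treat the leading `-` as a new list item. A minimal sketch of how such a filter behaves, using a hypothetical parametrization rather than the real matrix in test_fusions_e2e.py:

# Hypothetical example: running this file with
#   pytest -k 'TRITON and -quant_fp8'
# collects only the test IDs that contain both the substring "TRITON" and the
# substring "-quant_fp8" (the no-custom-op Triton cases) and skips the rest.
import pytest

@pytest.mark.parametrize("backend", ["TRITON", "FLASHINFER"])
@pytest.mark.parametrize("custom_ops", ["-quant_fp8", "+quant_fp8"])
def test_fusion_sketch(backend, custom_ops):
    assert backend and custom_ops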

README.md

Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ vLLM is flexible and easy to use with:
 - Tensor, pipeline, data and expert parallelism support for distributed inference
 - Streaming outputs
 - OpenAI-compatible API server
-- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
+- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, Arm CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
 - Prefix caching support
 - Multi-LoRA support

csrc/quantization/fp4/nvfp4_quant_kernels.cu

Lines changed: 20 additions & 2 deletions

@@ -31,6 +31,13 @@

 namespace vllm {

+template <typename Int>
+__host__ __device__ inline Int round_up(Int x, Int y) {
+  static_assert(std::is_integral_v<Int>,
+                "round_up argument must be integral type");
+  return (x + y - 1) / y * y;
+}
+
 // Use UE4M3 by default.
 template <class Type, bool UE8M0_SF = false>
 __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
@@ -42,10 +49,21 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
   static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                 "Vec size is not matched.");

+  int sf_m = round_up<int>(numRows, 128);
+  int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
+  int sf_n_int = round_up<int>(sf_n_unpadded, 4) / 4;
+  for (int row = numRows + blockIdx.x; row < sf_m; row += gridDim.x) {
+    // Each thread writes 4 uint32_t elements.
+    for (int col = sf_n_unpadded + threadIdx.x * 4; col < sf_n_int;
+         col += blockDim.x * 4) {
+      SFout[row * sf_n_int + col] = 0x00;
+    }
+  }
+
   // Get the global scaling factor, which will be applied to the SF.
   // Note SFScale is the same as next GEMM's alpha, which is
   // (448.f / (Alpha_A / 6.f)).
-  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
+  float const global_scale = SFScale == nullptr ? 1.0f : SFScale[0];

   // Input tensor row/col loops.
   for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
@@ -64,7 +82,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
           rowIdx, colIdx, numCols, SFout);

       out_pos =
-          cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+          cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, global_scale, sf_out);
     }
   }
 }
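The new `round_up` helper and zero-fill loop pad the swizzled scale-factor output so the row count is a multiple of 128 and the packed column count is a whole number of uint32 words. A rough Python sketch of that shape arithmetic, assuming one FP8 scale per 16 elements (CVT_FP4_SF_VEC_SIZE == 16, as in the surrounding kernel) packed four scales per uint32:

# Sketch of the padded scale-factor buffer shape used by the kernel above.
def round_up(x: int, y: int) -> int:
    return (x + y - 1) // y * y

def sf_buffer_shape(num_rows: int, num_cols: int, sf_vec_size: int = 16):
    sf_m = round_up(num_rows, 128)                # rows padded to a multiple of 128
    sf_n_unpadded = num_cols // sf_vec_size       # raw number of scales per row
    sf_n_int = round_up(sf_n_unpadded, 4) // 4    # uint32 words per row
    return sf_m, sf_n_int

print(sf_buffer_shape(100, 1024))  # -> (128, 16)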

docs/README.md

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ vLLM is flexible and easy to use with:
 - Tensor, pipeline, data and expert parallelism support for distributed inference
 - Streaming outputs
 - OpenAI-compatible API server
-- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, Arm CPUs and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
+- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, Arm CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
 - Prefix caching support
 - Multi-LoRA support

docs/usage/troubleshooting.md

Lines changed: 4 additions & 0 deletions

@@ -316,6 +316,10 @@ Traceback (most recent call last):

 This indicates vLLM failed to initialize the NCCL communicator, possibly due to a missing `IPC_LOCK` linux capability or an unmounted `/dev/shm`. Refer to [Enabling GPUDirect RDMA](../serving/parallelism_scaling.md#enabling-gpudirect-rdma) for guidance on properly configuring the environment for GPUDirect RDMA.

+## CUDA error: the provided PTX was compiled with an unsupported toolchain
+
+If you see an error like `RuntimeError: CUDA error: the provided PTX was compiled with an unsupported toolchain.`, it means the CUDA PTX in vLLM's wheels was compiled with a toolchain your system does not support. The released vLLM wheels are compiled with a specific version of the CUDA toolkit, and the compiled code may fail to run on older CUDA drivers. See [CUDA compatibility](https://docs.nvidia.com/deploy/cuda-compatibility/) for details. The solution is to install the `cuda-compat` package from your package manager. For example, on Ubuntu you can run `sudo apt-get install cuda-compat-12-9` and then add `export LD_LIBRARY_PATH=/usr/local/cuda-12.9/compat:$LD_LIBRARY_PATH` to your `.bashrc` file. Once it is installed successfully, the output of `nvidia-smi` should show `CUDA Version: 12.9`. Note that CUDA 12.9 is only an example; you may need a higher version of the cuda-compat package if vLLM's default CUDA version moves higher.
+
 ## Known Issues

 - In `v0.5.2`, `v0.5.3`, and `v0.5.3.post1`, there is a bug caused by [zmq](https://github.com/zeromq/pyzmq/issues/2000), which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of `vllm` to include the [fix](https://github.com/vllm-project/vllm/pull/6759).
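As a quick check for the PTX/toolchain mismatch described in the new troubleshooting entry above, you can compare the CUDA toolkit version the installed wheel was built against with the CUDA version the driver reports. An illustrative snippet (an assumption, not part of vLLM; it only requires PyTorch and `nvidia-smi`):

# Illustrative check: if the wheel's build CUDA version is newer than the
# driver-reported "CUDA Version", the cuda-compat package (or a driver
# upgrade) is likely needed.
import subprocess
import torch

print("wheel built against CUDA:", torch.version.cuda)
smi = subprocess.check_output(["nvidia-smi"], text=True)
print(next(line.strip() for line in smi.splitlines() if "CUDA Version" in line))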

tests/compile/test_fusions_e2e.py

Lines changed: 5 additions & 7 deletions

@@ -54,11 +54,11 @@ class ModelBackendTestCase(NamedTuple):

 MODELS_FP4 = [
     ModelBackendTestCase(
-        model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
+        model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
         model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
         backend=_Backend.FLASHINFER,
-        attention_fusions=48,
-        allreduce_fusions=96,
+        attention_fusions=32,
+        allreduce_fusions=65,
     ),
 ]

@@ -95,8 +95,7 @@ class ModelBackendTestCase(NamedTuple):
     ),
 ]

-# TODO(luka) test both in nightly
-CUSTOM_OPS_FP8 = ["-quant_fp8"]  # , "+quant_fp8"]
+CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]


 @pytest.mark.parametrize(
@@ -171,8 +170,7 @@ def test_attn_quant(
     assert int(matches[0]) == attention_fusions


-# TODO(luka) test both in nightly
-CUSTOM_OPS_RMS_NORM = ["-rms_norm"]  # , "+rms_norm"]
+CUSTOM_OPS_RMS_NORM = ["-rms_norm", "+rms_norm"]


 def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
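The hunk above ends at the signature of `custom_ops_product`, whose body is not shown. For context, a plausible sketch (an assumption, not the actual vLLM implementation) of how such a helper could combine the now fully enabled `CUSTOM_OPS_FP8` and `CUSTOM_OPS_RMS_NORM` lists into per-test custom-ops strings:

# Hypothetical combinator; the real helper lives in
# tests/compile/test_fusions_e2e.py and may differ.
from collections.abc import Iterable
from itertools import product

CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
CUSTOM_OPS_RMS_NORM = ["-rms_norm", "+rms_norm"]

def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
    # Cartesian product of the per-op choices, joined into one
    # comma-separated custom_ops string per test case.
    for combo in product(*custom_ops_lists):
        yield ",".join(combo)

print(list(custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)))
# ['-quant_fp8,-rms_norm', '-quant_fp8,+rms_norm',
#  '+quant_fp8,-rms_norm', '+quant_fp8,+rms_norm']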

tests/kernels/quantization/test_nvfp4_quant.py

Lines changed: 0 additions & 2 deletions

@@ -168,9 +168,7 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
     out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)

     out, out_scale = ops.scaled_fp4_quant(x, global_scale)
-
     scale_ans = recover_swizzled_scales(out_scale, m, n)
     out_ans = cast_from_fp4(out, m, n)
-
     torch.testing.assert_close(out_ans, out_ref)
     torch.testing.assert_close(scale_ans, scale_ref)

tests/samplers/test_logprobs.py

Lines changed: 96 additions & 0 deletions (new file)

@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm import SamplingParams
from vllm.logprobs import FlattenLogprobs

MODELS = ["distilbert/distilgpt2"]
MAX_TOKENS = 5
NUM_TOP_LOGPROBS = 5
NUM_PROMPT_LOGPROBS = 7
MAX_LOGPROBS = max(NUM_TOP_LOGPROBS, NUM_PROMPT_LOGPROBS)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("greedy", [True, False])
@pytest.mark.parametrize("flatten_logprobs", [True, False])
def test_ranks(
    vllm_runner,
    model,
    dtype,
    greedy,
    flatten_logprobs,
    example_prompts,
    monkeypatch: pytest.MonkeyPatch,
):
    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1" if flatten_logprobs else "0")
    with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
        tokenizer = vllm_model.llm.get_tokenizer()
        example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
        sampling_params = SamplingParams(
            temperature=0.0 if greedy else 1.0,
            top_p=1.0,
            max_tokens=MAX_TOKENS,
            logprobs=NUM_TOP_LOGPROBS,
            prompt_logprobs=NUM_PROMPT_LOGPROBS,
        )
        results = vllm_model.generate_w_logprobs(example_prompts, sampling_params)

    assert len(results) == len(example_prompt_tokens)
    for i, (result, prompt_tokens) in enumerate(zip(results, example_prompt_tokens)):
        decode_tokens, _, decode_logprobs, prompt_logprobs = result

        # Ensure the return type of logprobs is accurate
        assert isinstance(
            prompt_logprobs, FlattenLogprobs if flatten_logprobs else list
        )
        assert isinstance(
            decode_logprobs, FlattenLogprobs if flatten_logprobs else list
        )

        ########################
        # Check prompt logprobs
        ########################
        assert len(prompt_tokens) == len(prompt_logprobs)
        # No logprob for first prompt token
        assert not prompt_logprobs[0]
        for position, (token, logprobs) in enumerate(
            zip(prompt_tokens[1:], prompt_logprobs[1:]), start=1
        ):
            # Ensure logprobs of prompt token is always returned
            logprob = logprobs.get(token)
            assert logprob is not None
            assert logprob.rank >= 1
            # Ensure # of returned logprobs should be
            # either NUM_PROMPT_LOGPROBS or NUM_PROMPT_LOGPROBS+1
            assert NUM_PROMPT_LOGPROBS <= len(logprobs) <= NUM_PROMPT_LOGPROBS + 1
            # Ensure top NUM_PROMPT_LOGPROBS is always extracted
            assert set(range(1, NUM_PROMPT_LOGPROBS + 1)).issubset(
                {logprob.rank for logprob in logprobs.values()}
            )

        ########################
        # Check sample logprobs
        ########################
        assert len(decode_tokens) == len(decode_logprobs)
        for position, (token, logprobs) in enumerate(
            zip(decode_tokens, decode_logprobs)
        ):
            # Ensure logprobs of chosen token is always returned
            logprob = logprobs.get(token)
            assert logprob is not None
            if greedy:
                # For greedy sampling, all chosen logprob should be top ranked
                assert logprob.rank == 1
            else:
                assert logprob.rank >= 1
            # Ensure # of returned logprobs should be
            # either NUM_TOP_LOGPROBS or NUM_TOP_LOGPROBS+1
            assert NUM_TOP_LOGPROBS <= len(logprobs) <= NUM_TOP_LOGPROBS + 1
            # Ensure top NUM_TOP_LOGPROBS logprobs is always extracted
            assert set(range(1, NUM_TOP_LOGPROBS + 1)).issubset(
                {logprob.rank for logprob in logprobs.values()}
            )
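The `NUM_TOP_LOGPROBS` / `NUM_TOP_LOGPROBS + 1` bounds asserted above reflect that the chosen token's logprob is reported even when it falls outside the requested top-k. A toy illustration of that invariant (not vLLM internals):

# Toy illustration: the returned ranks are the top-k plus, possibly, the rank
# of the sampled token when it is not already among the top-k.
NUM_TOP_LOGPROBS = 5
top_k_ranks = set(range(1, NUM_TOP_LOGPROBS + 1))
for sampled_rank in (1, 12):  # greedy-like pick vs. a token sampled outside the top-k
    returned = top_k_ranks | {sampled_rank}
    assert NUM_TOP_LOGPROBS <= len(returned) <= NUM_TOP_LOGPROBS + 1
    assert top_k_ranks.issubset(returned)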

tests/samplers/test_ranks.py

Lines changed: 0 additions & 59 deletions
This file was deleted.
