.buildkite/models/google_gemma-3-27b-it.yml (1 addition, 1 deletion)
@@ -8,7 +8,7 @@ steps:
commands:
- |
.buildkite/scripts/run_in_docker.sh \
- bash -c 'SKIP_JAX_PRECOMPILE=1 VLLM_XLA_CHECK_RECOMPILATION=0 python3 /workspace/tpu_inference/examples/offline_inference.py --model=google/gemma-3-27b-it --tensor_parallel_size=8 --task=generate --max_model_len=1024 --max_num_seqs=1'
+ bash -c 'VLLM_XLA_CHECK_RECOMPILATION=0 python3 /workspace/tpu_inference/examples/offline_inference.py --model=google/gemma-3-27b-it --tensor_parallel_size=8 --task=generate --max_model_len=1024 --max_num_seqs=1'
- label: "Record unit test result for google/gemma-3-27b-it"
key: "record_google_gemma-3-27b-it_UnitTest"
depends_on: "google_gemma-3-27b-it_UnitTest"

examples/disagg/run_disagg_multi_host.sh (2 additions, 2 deletions)
@@ -63,7 +63,6 @@ for ((i=0; i<NUM_HOSTS_PER_INSTANCE; i++)); do
-e TPU_KV_TRANSFER_PORT="${KV_PORT}" \
-e TPU_SIDE_CHANNEL_PORT="${SIDE_PORT}" \
-e RAY_DEDUP_LOGS="0" \
- -e SKIP_JAX_PRECOMPILE="1" \
\
-e TPU_CHIPS_PER_PROCESS_BOUNDS="1,1,1" \
-e TPU_PROCESS_BOUNDS="2,2,1" \
@@ -95,6 +94,7 @@ docker exec node-0 /bin/bash -c \
--gpu-memory-utilization 0.3 \
--tensor-parallel-size 4 \
--kv-transfer-config '{\"kv_connector\":\"TPUConnector\",\"kv_connector_module_path\":\"tpu_inference.distributed.tpu_connector\",\"kv_role\":\"kv_producer\"}' \
+ --enforce-eager \
> /root/logs/prefill.txt 2>&1 &"
set +x

@@ -137,7 +137,6 @@ for ((i=0; i<NUM_HOSTS_PER_INSTANCE; i++)); do
-e TPU_KV_TRANSFER_PORT="${KV_PORT}" \
-e TPU_SIDE_CHANNEL_PORT="${SIDE_PORT}" \
-e RAY_DEDUP_LOGS="0" \
- -e SKIP_JAX_PRECOMPILE="1" \
\
-e TPU_CHIPS_PER_PROCESS_BOUNDS="1,1,1" \
-e TPU_PROCESS_BOUNDS="2,2,1" \
@@ -169,5 +168,6 @@ docker exec node-20 /bin/bash -c \
--gpu-memory-utilization 0.3 \
--tensor-parallel-size 4 \
--kv-transfer-config '{\"kv_connector\":\"TPUConnector\",\"kv_connector_module_path\":\"tpu_inference.distributed.tpu_connector\",\"kv_role\":\"kv_consumer\"}' \
+ --enforce-eager \
> /root/logs/decode.txt 2>&1 &"
set +x

examples/disagg/run_disagg_single_host.sh (2 additions, 2 deletions)
@@ -45,13 +45,13 @@ for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
\
TPU_KV_TRANSFER_PORT=$KV_PORT \
TPU_SIDE_CHANNEL_PORT=$SIDE_PORT \
- SKIP_JAX_PRECOMPILE=1 \
\
vllm serve $MODEL \
--port $PORT \
--gpu-memory-utilization 0.2 \
--tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config "{\"kv_connector\":\"TPUConnector\",\"kv_connector_module_path\":\"tpu_inference.distributed.tpu_connector\",\"kv_role\":\"kv_producer\"}" \
+ --enforce-eager \
> $HOME/logs/prefill_$i.txt 2>&1 &

PREFILL_HOSTS+=("localhost")
@@ -72,13 +72,13 @@ for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do
\
TPU_KV_TRANSFER_PORT=$KV_PORT \
TPU_SIDE_CHANNEL_PORT=$SIDE_PORT \
- SKIP_JAX_PRECOMPILE=1 \
\
vllm serve $MODEL \
--port $PORT \
--gpu-memory-utilization 0.2 \
--tensor-parallel-size $DECODER_TP_SIZE \
--kv-transfer-config "{\"kv_connector\":\"TPUConnector\",\"kv_connector_module_path\":\"tpu_inference.distributed.tpu_connector\",\"kv_role\":\"kv_consumer\"}" \
+ --enforce-eager \
> $HOME/logs/decode_$i.txt 2>&1 &

DECODE_HOSTS+=("localhost")

examples/offline_inference.py (3 additions, 5 deletions)
@@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

- import os

import vllm.envs as envs
from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
@@ -17,6 +15,9 @@ def create_parser():
parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
parser.set_defaults(max_model_len=1024)

+ # Skip long warmup for local simple test.
+ parser.set_defaults(enforce_eager=True)

# Add sampling params
sampling_group = parser.add_argument_group("Sampling parameters")
sampling_group.add_argument("--max-tokens", type=int)
@@ -103,9 +104,6 @@ def main(args: dict):


if __name__ == "__main__":
- # Skip long warmup for local simple test.
- os.environ['SKIP_JAX_PRECOMPILE'] = '1'

parser = create_parser()
args: dict = vars(parser.parse_args())

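
For context, a minimal offline-inference sketch (not taken from this diff; the prompt is arbitrary) of what the enforce_eager default above means when the engine is built directly. The model and max_model_len values mirror the example script's defaults:

    from vllm import LLM

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",  # example script's default model
        max_model_len=1024,                        # example script's default length
        enforce_eager=True,  # replaces the removed SKIP_JAX_PRECOMPILE=1 env var
    )
    outputs = llm.generate(["Hello, my name is"])
    print(outputs[0].outputs[0].text)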

examples/offline_lora_inference.py (3 additions, 4 deletions)
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

- import os
import time

import vllm.envs as envs
@@ -20,6 +19,9 @@ def create_parser():
parser.set_defaults(enable_lora=True)
parser.set_defaults(max_lora_rank=8)

+ # Skip long warmup for local simple test.
+ parser.set_defaults(enforce_eager=True)

# Add sampling params
sampling_group = parser.add_argument_group("Sampling parameters")
sampling_group.add_argument("--max-tokens", type=int, default=16)
@@ -76,9 +78,6 @@ def main(args: dict):


if __name__ == "__main__":
- # Skip long warmup for local simple test.
- os.environ['SKIP_JAX_PRECOMPILE'] = '1'

parser = create_parser()
args: dict = vars(parser.parse_args())

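
Along the same lines, a hedged sketch of offline LoRA inference with the eager default; the base model, adapter name, and adapter path below are placeholders, while enable_lora and max_lora_rank mirror the defaults this example sets:

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",  # placeholder base model
        enable_lora=True,    # default set via create_parser()
        max_lora_rank=8,     # default set via create_parser()
        enforce_eager=True,  # skip the TPU precompilation pass
    )
    outputs = llm.generate(
        ["What is the capital of France?"],
        SamplingParams(max_tokens=16),
        lora_request=LoRARequest("demo-adapter", 1, "/path/to/lora-adapter"),  # placeholder adapter
    )
    print(outputs[0].outputs[0].text)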

tests/e2e/benchmarking/mm_bench.sh (1 addition, 1 deletion)
@@ -91,7 +91,7 @@ checkThroughputAndRouge() {
}

echo "Spinning up the vLLM server..."
- (SKIP_JAX_PRECOMPILE=1 VLLM_XLA_CHECK_RECOMPILATION=0 vllm serve "$model_name" --max-model-len "$max_model_len" --max-num-seqs "$max_num_seqs" --disable-log-requests --max-num-batched-tokens "$max_batched_tokens" 2>&1 | tee -a "$LOG_FILE") &
+ (VLLM_XLA_CHECK_RECOMPILATION=0 vllm serve "$model_name" --max-model-len "$max_model_len" --max-num-seqs "$max_num_seqs" --disable-log-requests --max-num-batched-tokens "$max_batched_tokens" --enforce-eager 2>&1 | tee -a "$LOG_FILE") &


# Run a busy loop to block until the server is ready to receive requests

tests/e2e/test_multi_modal_inference.py (1 addition, 1 deletion)
@@ -24,7 +24,6 @@ def test_multi_modal_inference(monkeypatch):
"""
Runs multi-modal inference and verifies the output.
"""
- os.environ['SKIP_JAX_PRECOMPILE'] = '1'  # Skip warmup to save time.
os.environ[
'VLLM_XLA_CHECK_RECOMPILATION'] = '0' # Allow compilation during execution.

@@ -65,6 +64,7 @@
"fps": 1,
},
limit_mm_per_prompt={modality: 1},
+ enforce_eager=True,  # Skip warmup to save time.
)
engine_args = asdict(engine_args)
llm = LLM(**engine_args)

tpu_inference/runner/compilation_manager.py (1 addition, 3 deletions)
@@ -1,4 +1,3 @@
- import os
import time
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple

@@ -67,8 +66,7 @@ def _run_compilation(self, name: str, fn: Callable, *args,
logger.info("Compilation finished in %.2f [secs].", end - start)

def capture_model(self) -> None:
- if os.getenv("SKIP_JAX_PRECOMPILE",
-              False) or self.runner.model_config.enforce_eager:
+ if self.runner.model_config.enforce_eager:
return
logger.info("Precompile all the subgraphs with possible input shapes.")

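
To make the remaining behavior concrete, a small self-contained sketch (the classes are stand-ins, not the real runner types): with the os.getenv("SKIP_JAX_PRECOMPILE", ...) check gone, enforce_eager on the model config is the only switch that skips precompilation.

    from dataclasses import dataclass

    @dataclass
    class _StubModelConfig:
        enforce_eager: bool

    @dataclass
    class _StubRunner:
        model_config: _StubModelConfig

    def capture_model(runner: _StubRunner) -> None:
        # Mirrors the updated gate: return early only when eager mode is requested.
        if runner.model_config.enforce_eager:
            return
        print("Precompile all the subgraphs with possible input shapes.")

    capture_model(_StubRunner(_StubModelConfig(enforce_eager=True)))   # skips precompilation
    capture_model(_StubRunner(_StubModelConfig(enforce_eager=False)))  # would precompile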