Skip to content

Commit 26c3f42

Browse files
xingliu14AahilA
authored andcommitted
Add test_envs.py (vllm-project#1079)
Signed-off-by: Xing Liu <[email protected]>
1 parent 00cb97c commit 26c3f42

File tree

1 file changed

+182
-0
lines changed

1 file changed

+182
-0
lines changed

tests/test_envs.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the tpu-inference project
3+
4+
import pytest
5+
6+
import tpu_inference.envs as envs
7+
from tpu_inference.envs import enable_envs_cache, environment_variables
8+
9+
10+
def test_getattr_without_cache(monkeypatch: pytest.MonkeyPatch):
11+
assert envs.JAX_PLATFORMS == ""
12+
assert envs.PHASED_PROFILING_DIR == ""
13+
monkeypatch.setenv("JAX_PLATFORMS", "tpu")
14+
monkeypatch.setenv("PHASED_PROFILING_DIR", "/tmp/profiling")
15+
assert envs.JAX_PLATFORMS == "tpu"
16+
assert envs.PHASED_PROFILING_DIR == "/tmp/profiling"
17+
18+
assert envs.TPU_NAME is None
19+
assert envs.TPU_ACCELERATOR_TYPE is None
20+
monkeypatch.setenv("TPU_NAME", "my-tpu")
21+
monkeypatch.setenv("TPU_ACCELERATOR_TYPE", "v5litepod-16")
22+
assert envs.TPU_NAME == "my-tpu"
23+
assert envs.TPU_ACCELERATOR_TYPE == "v5litepod-16"
24+
25+
# __getattr__ is not decorated with functools.cache
26+
assert not hasattr(envs.__getattr__, "cache_info")
27+
28+
29+
def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
30+
monkeypatch.setenv("JAX_PLATFORMS", "tpu")
31+
monkeypatch.setenv("TPU_NAME", "my-tpu")
32+
33+
# __getattr__ is not decorated with functools.cache
34+
assert not hasattr(envs.__getattr__, "cache_info")
35+
36+
enable_envs_cache()
37+
38+
# __getattr__ is decorated with functools.cache
39+
assert hasattr(envs.__getattr__, "cache_info")
40+
start_hits = envs.__getattr__.cache_info().hits
41+
42+
# 2 more hits due to JAX_PLATFORMS and TPU_NAME accesses
43+
assert envs.JAX_PLATFORMS == "tpu"
44+
assert envs.TPU_NAME == "my-tpu"
45+
assert envs.__getattr__.cache_info().hits == start_hits + 2
46+
47+
# All environment variables are cached
48+
for environment_variable in environment_variables:
49+
envs.__getattr__(environment_variable)
50+
assert envs.__getattr__.cache_info(
51+
).hits == start_hits + 2 + len(environment_variables)
52+
53+
# Reset envs.__getattr__ back to non-cached version to
54+
# avoid affecting other tests
55+
envs.__getattr__ = envs.__getattr__.__wrapped__
56+
57+
58+
def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
59+
# Test SKIP_JAX_PRECOMPILE (default False)
60+
assert envs.SKIP_JAX_PRECOMPILE is False
61+
monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "1")
62+
assert envs.SKIP_JAX_PRECOMPILE is True
63+
monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
64+
assert envs.SKIP_JAX_PRECOMPILE is False
65+
66+
# Test NEW_MODEL_DESIGN (default False)
67+
assert envs.NEW_MODEL_DESIGN is False
68+
monkeypatch.setenv("NEW_MODEL_DESIGN", "1")
69+
assert envs.NEW_MODEL_DESIGN is True
70+
71+
# Test USE_MOE_EP_KERNEL (default False)
72+
assert envs.USE_MOE_EP_KERNEL is False
73+
monkeypatch.setenv("USE_MOE_EP_KERNEL", "1")
74+
assert envs.USE_MOE_EP_KERNEL is True
75+
76+
77+
def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
78+
assert envs.PYTHON_TRACER_LEVEL == 1
79+
monkeypatch.setenv("PYTHON_TRACER_LEVEL", "3")
80+
assert envs.PYTHON_TRACER_LEVEL == 3
81+
monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0")
82+
assert envs.PYTHON_TRACER_LEVEL == 0
83+
84+
85+
def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch):
86+
monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC")
87+
assert envs.TPU_MULTIHOST_BACKEND == "grpc"
88+
89+
monkeypatch.setenv("MODEL_IMPL_TYPE", "FLAX_NNX")
90+
assert envs.MODEL_IMPL_TYPE == "flax_nnx"
91+
92+
93+
def test_string_env_vars_defaults(monkeypatch: pytest.MonkeyPatch):
94+
monkeypatch.delenv("JAX_PLATFORMS", raising=False)
95+
monkeypatch.delenv("PREFILL_SLICES", raising=False)
96+
monkeypatch.delenv("DECODE_SLICES", raising=False)
97+
98+
assert envs.JAX_PLATFORMS == ""
99+
assert envs.PREFILL_SLICES == ""
100+
assert envs.DECODE_SLICES == ""
101+
assert envs.PHASED_PROFILING_DIR == ""
102+
103+
104+
def test_none_default_env_vars(monkeypatch: pytest.MonkeyPatch):
105+
monkeypatch.delenv("TPU_ACCELERATOR_TYPE", raising=False)
106+
monkeypatch.delenv("TPU_NAME", raising=False)
107+
monkeypatch.delenv("TPU_WORKER_ID", raising=False)
108+
109+
assert envs.TPU_ACCELERATOR_TYPE is None
110+
assert envs.TPU_NAME is None
111+
assert envs.TPU_WORKER_ID is None
112+
113+
114+
def test_ray_env_vars(monkeypatch: pytest.MonkeyPatch):
115+
assert envs.RAY_USAGE_STATS_ENABLED == "0"
116+
monkeypatch.setenv("RAY_USAGE_STATS_ENABLED", "1")
117+
assert envs.RAY_USAGE_STATS_ENABLED == "1"
118+
119+
assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "shm"
120+
monkeypatch.setenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "nccl")
121+
assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"
122+
123+
124+
def test_invalid_attribute_raises_error():
125+
with pytest.raises(AttributeError,
126+
match="has no attribute 'NONEXISTENT_VAR'"):
127+
_ = envs.NONEXISTENT_VAR
128+
129+
130+
def test_dir_returns_all_env_vars():
131+
env_vars = envs.__dir__()
132+
assert isinstance(env_vars, list)
133+
assert len(env_vars) == len(environment_variables)
134+
assert "JAX_PLATFORMS" in env_vars
135+
assert "TPU_NAME" in env_vars
136+
assert "SKIP_JAX_PRECOMPILE" in env_vars
137+
assert "MODEL_IMPL_TYPE" in env_vars
138+
139+
140+
def test_tpu_multihost_env_vars(monkeypatch: pytest.MonkeyPatch):
141+
monkeypatch.setenv("TPU_WORKER_ID", "0")
142+
assert envs.TPU_WORKER_ID == "0"
143+
144+
monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "grpc")
145+
assert envs.TPU_MULTIHOST_BACKEND == "grpc"
146+
147+
monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "xla")
148+
assert envs.TPU_MULTIHOST_BACKEND == "xla"
149+
150+
151+
def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
152+
monkeypatch.setenv("PREFILL_SLICES", "0,1,2,3")
153+
assert envs.PREFILL_SLICES == "0,1,2,3"
154+
155+
monkeypatch.setenv("DECODE_SLICES", "4,5,6,7")
156+
assert envs.DECODE_SLICES == "4,5,6,7"
157+
158+
159+
def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
160+
monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
161+
assert envs.MODEL_IMPL_TYPE == "flax_nnx"
162+
163+
164+
def test_cache_preserves_values_across_env_changes(
165+
monkeypatch: pytest.MonkeyPatch):
166+
monkeypatch.setenv("JAX_PLATFORMS", "tpu")
167+
168+
enable_envs_cache()
169+
170+
assert envs.JAX_PLATFORMS == "tpu"
171+
172+
# Change environment variable
173+
monkeypatch.setenv("JAX_PLATFORMS", "cpu")
174+
175+
# Cached value should still be "tpu"
176+
assert envs.JAX_PLATFORMS == "tpu"
177+
178+
# Reset envs.__getattr__ back to non-cached version
179+
envs.__getattr__ = envs.__getattr__.__wrapped__
180+
181+
# Now it should reflect the new value
182+
assert envs.JAX_PLATFORMS == "cpu"

0 commit comments

Comments
 (0)