Skip to content

Commit 00cb97c

Browse files
kyuyeunkAahilA
authored andcommitted
[Torchax] Add initial support for loading mxfp4 (vllm-project#1080)
Signed-off-by: Kyuyeun Kim <[email protected]>
1 parent e703c97 commit 00cb97c

File tree

6 files changed

+460
-13
lines changed

6 files changed

+460
-13
lines changed

tests/layers/vllm/test_mxfp4.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
import tempfile
2+
3+
import jax
4+
import jax.numpy as jnp
5+
import pytest
6+
import torch
7+
import torchax
8+
import utils as test_utils
9+
from jax.sharding import NamedSharding, PartitionSpec
10+
from torchax.interop import torch_view
11+
from torchax.ops.mappings import j2t, t2j
12+
from vllm.config import set_current_vllm_config
13+
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
14+
init_distributed_environment)
15+
from vllm.engine.arg_utils import EngineArgs
16+
from vllm.forward_context import set_forward_context
17+
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
18+
19+
from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
20+
from tpu_inference.layers.vllm.quantization.mxfp4 import (VllmMxfp4Config,
21+
VllmMxfp4MoEMethod)
22+
23+
P = PartitionSpec
24+
MODELS = ["openai/gpt-oss-20b"]
25+
MXFP4_BLOCK_SIZE = 32
26+
27+
28+
def quantize_to_mxfp4(weight: torch.tensor):
29+
# Utilize JAX because native support for e2m1 makes it easier to work with.
30+
weight = t2j(weight)
31+
e2m1_finfo = jnp.finfo(jnp.float4_e2m1fn)
32+
dtype_min = float(e2m1_finfo.min)
33+
dtype_max = float(e2m1_finfo.max)
34+
35+
# Do a subchannel quantization where block size is 32.
36+
weight_shape = weight.shape
37+
weight_block = weight.reshape(weight_shape[:-1] + (-1, MXFP4_BLOCK_SIZE))
38+
abs_max = jnp.max(jnp.abs(weight_block), axis=-1, keepdims=True)
39+
scale = abs_max / dtype_max
40+
41+
weight_q = jnp.clip(weight_block / scale, dtype_min, dtype_max)
42+
weight_q = weight_q.astype(jnp.float4_e2m1fn).reshape(weight_shape[:-1] +
43+
(-1, 2))
44+
weight_packed = jax.lax.bitcast_convert_type(weight_q, jnp.uint8)
45+
46+
# We convert scale into e8m0 manually because there is no hardware support.
47+
e8m0_finfo = jnp.finfo(jnp.float8_e8m0fnu)
48+
_, scale_exp = jnp.frexp(scale.squeeze(axis=-1))
49+
# Subtract by one sinced e8m0 has no decimal
50+
scale_exp -= 1
51+
scale_exp = (scale_exp - e8m0_finfo.minexp).astype(jnp.uint8)
52+
53+
return j2t(weight_packed), j2t(scale_exp)
54+
55+
56+
@pytest.fixture(autouse=True)
57+
def setup_environment():
58+
# This is a fake config used for init dist env.
59+
# RowParallelLinear needs dist env to be initialized.
60+
engine_args = EngineArgs(
61+
model=MODELS[0],
62+
max_model_len=64,
63+
max_num_batched_tokens=64,
64+
max_num_seqs=4,
65+
load_format='dummy',
66+
)
67+
68+
vllm_config = engine_args.create_engine_config()
69+
70+
with set_current_vllm_config(vllm_config):
71+
temp_file = tempfile.mkstemp()[1]
72+
init_distributed_environment(
73+
1,
74+
0,
75+
local_rank=0,
76+
distributed_init_method=f"file://{temp_file}",
77+
backend="gloo")
78+
ensure_model_parallel_initialized(1, 1)
79+
80+
81+
@pytest.mark.parametrize("model", MODELS)
82+
@pytest.mark.parametrize("mesh", [
83+
test_utils.get_spmd_mesh(1),
84+
test_utils.get_spmd_mesh(jax.local_device_count())
85+
])
86+
def test_quant_override(model, mesh):
87+
88+
engine_args = EngineArgs(
89+
model=model,
90+
max_model_len=64,
91+
max_num_batched_tokens=64,
92+
max_num_seqs=4,
93+
load_format='dummy',
94+
)
95+
vllm_config = engine_args.create_engine_config()
96+
vllm_config.model_config.dtype = torch.bfloat16
97+
98+
quant_config = get_tpu_quantization_config(vllm_config, mesh)
99+
assert isinstance(quant_config, VllmMxfp4Config)
100+
assert quant_config.vllm_config == vllm_config
101+
assert quant_config.mesh == mesh
102+
103+
104+
@pytest.mark.parametrize("mesh", [
105+
test_utils.get_spmd_mesh(1),
106+
test_utils.get_spmd_mesh(jax.local_device_count())
107+
])
108+
@pytest.mark.parametrize("num_tokens", [8])
109+
@pytest.mark.parametrize("intermediate_size", [1024])
110+
@pytest.mark.parametrize("hidden_size", [128])
111+
@pytest.mark.parametrize("num_experts", [8])
112+
@pytest.mark.parametrize("topk", [2])
113+
def test_fused_moe_bias(mesh, num_tokens, intermediate_size, hidden_size,
114+
num_experts, topk):
115+
torch.manual_seed(42)
116+
dtype = torch.bfloat16
117+
118+
a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
119+
w1 = torch.randn(
120+
(num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
121+
w2 = torch.randn(
122+
(num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
123+
w1_weight, w1_weight_scale = quantize_to_mxfp4(w1)
124+
w2_weight, w2_weight_scale = quantize_to_mxfp4(w2)
125+
126+
print(f'kky {w1_weight.shape=} {w1_weight_scale.shape=}')
127+
128+
w1_bias = torch.randn(
129+
(num_experts, 2 * intermediate_size), dtype=dtype) / 10
130+
w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype) / 10
131+
score = torch.randn((num_tokens, num_experts), dtype=dtype)
132+
133+
engine_args = EngineArgs(
134+
model=MODELS[0],
135+
max_model_len=64,
136+
max_num_batched_tokens=64,
137+
max_num_seqs=4,
138+
load_format='dummy',
139+
)
140+
vllm_config = engine_args.create_engine_config()
141+
vllm_config.model_config.dtype = dtype
142+
143+
quant_config = get_tpu_quantization_config(vllm_config, mesh)
144+
with set_current_vllm_config(vllm_config):
145+
vllm_fused_moe = FusedMoE(
146+
num_experts=num_experts,
147+
top_k=topk,
148+
hidden_size=hidden_size,
149+
intermediate_size=intermediate_size,
150+
reduce_results=False,
151+
renormalize=False,
152+
tp_size=1,
153+
dp_size=1,
154+
quant_config=quant_config,
155+
has_bias=True,
156+
)
157+
vllm_fused_moe.w13_weight.data = w1_weight
158+
vllm_fused_moe.w2_weight.data = w2_weight
159+
vllm_fused_moe.w13_weight_scale.data = w1_weight_scale
160+
vllm_fused_moe.w2_weight_scale.data = w2_weight_scale
161+
vllm_fused_moe.w13_bias.data = w1_bias
162+
vllm_fused_moe.w2_bias.data = w2_bias
163+
164+
with torchax.default_env(), set_forward_context(None, vllm_config):
165+
assert isinstance(vllm_fused_moe.quant_method, VllmMxfp4MoEMethod)
166+
167+
jax_a = a.to('jax')
168+
jax_a.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
169+
score = torch_view(t2j(score))
170+
score.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
171+
172+
vllm_fused_moe.quant_method.process_weights_after_loading(
173+
vllm_fused_moe)
174+
175+
# Because we are dequantizing mxfp4 weights for now, we verify if
176+
# dequantized weights matches with the original weights.
177+
# Due to NaN, comparing two values are difficult. Therefore, we utilize
178+
# nanmean instead.
179+
torch.testing.assert_close(torch.nanmean(vllm_fused_moe.w13_weight),
180+
torch.nanmean(w1),
181+
check_device=False,
182+
equal_nan=True,
183+
rtol=0.2,
184+
atol=0.1)
185+
torch.testing.assert_close(torch.nanmean(vllm_fused_moe.w2_weight),
186+
torch.nanmean(w2),
187+
check_device=False,
188+
equal_nan=True,
189+
rtol=0.2,
190+
atol=0.1)
191+
192+
vllm_fused_moe(jax_a, score)

tpu_inference/layers/vllm/quantization/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
1010
from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
1111
VllmCompressedTensorsConfig # noqa: E501
12+
from tpu_inference.layers.vllm.quantization.mxfp4 import VllmMxfp4Config
1213
from tpu_inference.layers.vllm.quantization.unquantized import \
1314
VllmUnquantizedConfig
1415

@@ -21,6 +22,7 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
2122
None: VllmUnquantizedConfig,
2223
"compressed-tensors": VllmCompressedTensorsConfig,
2324
"awq": VllmAWQConfig,
25+
"mxfp4": VllmMxfp4Config,
2426
}
2527
if model_config.quantization not in method_to_config:
2628
raise NotImplementedError(
@@ -30,6 +32,7 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
3032
assert issubclass(quant_config, JaxCommonConfig)
3133
quant_config.set_configs(vllm_config, mesh)
3234

33-
model_config.quantization = quant_config.get_name()
35+
# TODO(kyuyeunk): Create more programmatic way to handle this.
36+
model_config.quantization = "tpu-" + quant_config.get_name()
3437
return VllmConfig.get_quantization_config(model_config,
3538
vllm_config.load_config)

tpu_inference/layers/vllm/quantization/awq.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,9 @@
2929
logger = init_logger(__name__)
3030

3131

32-
@register_quantization_config("jax-awq")
32+
@register_quantization_config("tpu-awq")
3333
class VllmAWQConfig(AWQConfig, JaxCommonConfig):
3434

35-
@classmethod
36-
def get_name(cls) -> str:
37-
return "jax-awq"
38-
3935
def get_supported_act_dtypes(self) -> list[torch.dtype]:
4036
# NOTE: AWQ checkpoint was quantized with float16. But on TPUs, using
4137
# bfloat16 is signifcantly preferred over foat16. This might lead to

tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,9 @@
3030
logger = init_logger(__name__)
3131

3232

33-
@register_quantization_config("jax-compressed-tensors")
33+
@register_quantization_config("tpu-compressed-tensors")
3434
class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
3535

36-
@classmethod
37-
def get_name(cls) -> str:
38-
return "jax-compressed-tensors"
39-
4036
def get_scheme(self,
4137
layer: torch.nn.Module,
4238
layer_name: Optional[str] = None

0 commit comments

Comments
 (0)