|
| 1 | +import tempfile |
| 2 | + |
| 3 | +import jax |
| 4 | +import jax.numpy as jnp |
| 5 | +import pytest |
| 6 | +import torch |
| 7 | +import torchax |
| 8 | +import utils as test_utils |
| 9 | +from jax.sharding import NamedSharding, PartitionSpec |
| 10 | +from torchax.interop import torch_view |
| 11 | +from torchax.ops.mappings import j2t, t2j |
| 12 | +from vllm.config import set_current_vllm_config |
| 13 | +from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, |
| 14 | + init_distributed_environment) |
| 15 | +from vllm.engine.arg_utils import EngineArgs |
| 16 | +from vllm.forward_context import set_forward_context |
| 17 | +from vllm.model_executor.layers.fused_moe.layer import FusedMoE |
| 18 | + |
| 19 | +from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config |
| 20 | +from tpu_inference.layers.vllm.quantization.mxfp4 import (VllmMxfp4Config, |
| 21 | + VllmMxfp4MoEMethod) |
| 22 | + |
| 23 | +P = PartitionSpec |
| 24 | +MODELS = ["openai/gpt-oss-20b"] |
| 25 | +MXFP4_BLOCK_SIZE = 32 |
| 26 | + |
| 27 | + |
| 28 | +def quantize_to_mxfp4(weight: torch.tensor): |
| 29 | + # Utilize JAX because native support for e2m1 makes it easier to work with. |
| 30 | + weight = t2j(weight) |
| 31 | + e2m1_finfo = jnp.finfo(jnp.float4_e2m1fn) |
| 32 | + dtype_min = float(e2m1_finfo.min) |
| 33 | + dtype_max = float(e2m1_finfo.max) |
| 34 | + |
| 35 | + # Do a subchannel quantization where block size is 32. |
| 36 | + weight_shape = weight.shape |
| 37 | + weight_block = weight.reshape(weight_shape[:-1] + (-1, MXFP4_BLOCK_SIZE)) |
| 38 | + abs_max = jnp.max(jnp.abs(weight_block), axis=-1, keepdims=True) |
| 39 | + scale = abs_max / dtype_max |
| 40 | + |
| 41 | + weight_q = jnp.clip(weight_block / scale, dtype_min, dtype_max) |
| 42 | + weight_q = weight_q.astype(jnp.float4_e2m1fn).reshape(weight_shape[:-1] + |
| 43 | + (-1, 2)) |
| 44 | + weight_packed = jax.lax.bitcast_convert_type(weight_q, jnp.uint8) |
| 45 | + |
| 46 | + # We convert scale into e8m0 manually because there is no hardware support. |
| 47 | + e8m0_finfo = jnp.finfo(jnp.float8_e8m0fnu) |
| 48 | + _, scale_exp = jnp.frexp(scale.squeeze(axis=-1)) |
| 49 | + # Subtract by one sinced e8m0 has no decimal |
| 50 | + scale_exp -= 1 |
| 51 | + scale_exp = (scale_exp - e8m0_finfo.minexp).astype(jnp.uint8) |
| 52 | + |
| 53 | + return j2t(weight_packed), j2t(scale_exp) |
| 54 | + |
| 55 | + |
| 56 | +@pytest.fixture(autouse=True) |
| 57 | +def setup_environment(): |
| 58 | + # This is a fake config used for init dist env. |
| 59 | + # RowParallelLinear needs dist env to be initialized. |
| 60 | + engine_args = EngineArgs( |
| 61 | + model=MODELS[0], |
| 62 | + max_model_len=64, |
| 63 | + max_num_batched_tokens=64, |
| 64 | + max_num_seqs=4, |
| 65 | + load_format='dummy', |
| 66 | + ) |
| 67 | + |
| 68 | + vllm_config = engine_args.create_engine_config() |
| 69 | + |
| 70 | + with set_current_vllm_config(vllm_config): |
| 71 | + temp_file = tempfile.mkstemp()[1] |
| 72 | + init_distributed_environment( |
| 73 | + 1, |
| 74 | + 0, |
| 75 | + local_rank=0, |
| 76 | + distributed_init_method=f"file://{temp_file}", |
| 77 | + backend="gloo") |
| 78 | + ensure_model_parallel_initialized(1, 1) |
| 79 | + |
| 80 | + |
| 81 | +@pytest.mark.parametrize("model", MODELS) |
| 82 | +@pytest.mark.parametrize("mesh", [ |
| 83 | + test_utils.get_spmd_mesh(1), |
| 84 | + test_utils.get_spmd_mesh(jax.local_device_count()) |
| 85 | +]) |
| 86 | +def test_quant_override(model, mesh): |
| 87 | + |
| 88 | + engine_args = EngineArgs( |
| 89 | + model=model, |
| 90 | + max_model_len=64, |
| 91 | + max_num_batched_tokens=64, |
| 92 | + max_num_seqs=4, |
| 93 | + load_format='dummy', |
| 94 | + ) |
| 95 | + vllm_config = engine_args.create_engine_config() |
| 96 | + vllm_config.model_config.dtype = torch.bfloat16 |
| 97 | + |
| 98 | + quant_config = get_tpu_quantization_config(vllm_config, mesh) |
| 99 | + assert isinstance(quant_config, VllmMxfp4Config) |
| 100 | + assert quant_config.vllm_config == vllm_config |
| 101 | + assert quant_config.mesh == mesh |
| 102 | + |
| 103 | + |
| 104 | +@pytest.mark.parametrize("mesh", [ |
| 105 | + test_utils.get_spmd_mesh(1), |
| 106 | + test_utils.get_spmd_mesh(jax.local_device_count()) |
| 107 | +]) |
| 108 | +@pytest.mark.parametrize("num_tokens", [8]) |
| 109 | +@pytest.mark.parametrize("intermediate_size", [1024]) |
| 110 | +@pytest.mark.parametrize("hidden_size", [128]) |
| 111 | +@pytest.mark.parametrize("num_experts", [8]) |
| 112 | +@pytest.mark.parametrize("topk", [2]) |
| 113 | +def test_fused_moe_bias(mesh, num_tokens, intermediate_size, hidden_size, |
| 114 | + num_experts, topk): |
| 115 | + torch.manual_seed(42) |
| 116 | + dtype = torch.bfloat16 |
| 117 | + |
| 118 | + a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10 |
| 119 | + w1 = torch.randn( |
| 120 | + (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10 |
| 121 | + w2 = torch.randn( |
| 122 | + (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10 |
| 123 | + w1_weight, w1_weight_scale = quantize_to_mxfp4(w1) |
| 124 | + w2_weight, w2_weight_scale = quantize_to_mxfp4(w2) |
| 125 | + |
| 126 | + print(f'kky {w1_weight.shape=} {w1_weight_scale.shape=}') |
| 127 | + |
| 128 | + w1_bias = torch.randn( |
| 129 | + (num_experts, 2 * intermediate_size), dtype=dtype) / 10 |
| 130 | + w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype) / 10 |
| 131 | + score = torch.randn((num_tokens, num_experts), dtype=dtype) |
| 132 | + |
| 133 | + engine_args = EngineArgs( |
| 134 | + model=MODELS[0], |
| 135 | + max_model_len=64, |
| 136 | + max_num_batched_tokens=64, |
| 137 | + max_num_seqs=4, |
| 138 | + load_format='dummy', |
| 139 | + ) |
| 140 | + vllm_config = engine_args.create_engine_config() |
| 141 | + vllm_config.model_config.dtype = dtype |
| 142 | + |
| 143 | + quant_config = get_tpu_quantization_config(vllm_config, mesh) |
| 144 | + with set_current_vllm_config(vllm_config): |
| 145 | + vllm_fused_moe = FusedMoE( |
| 146 | + num_experts=num_experts, |
| 147 | + top_k=topk, |
| 148 | + hidden_size=hidden_size, |
| 149 | + intermediate_size=intermediate_size, |
| 150 | + reduce_results=False, |
| 151 | + renormalize=False, |
| 152 | + tp_size=1, |
| 153 | + dp_size=1, |
| 154 | + quant_config=quant_config, |
| 155 | + has_bias=True, |
| 156 | + ) |
| 157 | + vllm_fused_moe.w13_weight.data = w1_weight |
| 158 | + vllm_fused_moe.w2_weight.data = w2_weight |
| 159 | + vllm_fused_moe.w13_weight_scale.data = w1_weight_scale |
| 160 | + vllm_fused_moe.w2_weight_scale.data = w2_weight_scale |
| 161 | + vllm_fused_moe.w13_bias.data = w1_bias |
| 162 | + vllm_fused_moe.w2_bias.data = w2_bias |
| 163 | + |
| 164 | + with torchax.default_env(), set_forward_context(None, vllm_config): |
| 165 | + assert isinstance(vllm_fused_moe.quant_method, VllmMxfp4MoEMethod) |
| 166 | + |
| 167 | + jax_a = a.to('jax') |
| 168 | + jax_a.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None))) |
| 169 | + score = torch_view(t2j(score)) |
| 170 | + score.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None))) |
| 171 | + |
| 172 | + vllm_fused_moe.quant_method.process_weights_after_loading( |
| 173 | + vllm_fused_moe) |
| 174 | + |
| 175 | + # Because we are dequantizing mxfp4 weights for now, we verify if |
| 176 | + # dequantized weights matches with the original weights. |
| 177 | + # Due to NaN, comparing two values are difficult. Therefore, we utilize |
| 178 | + # nanmean instead. |
| 179 | + torch.testing.assert_close(torch.nanmean(vllm_fused_moe.w13_weight), |
| 180 | + torch.nanmean(w1), |
| 181 | + check_device=False, |
| 182 | + equal_nan=True, |
| 183 | + rtol=0.2, |
| 184 | + atol=0.1) |
| 185 | + torch.testing.assert_close(torch.nanmean(vllm_fused_moe.w2_weight), |
| 186 | + torch.nanmean(w2), |
| 187 | + check_device=False, |
| 188 | + equal_nan=True, |
| 189 | + rtol=0.2, |
| 190 | + atol=0.1) |
| 191 | + |
| 192 | + vllm_fused_moe(jax_a, score) |
0 commit comments