
Commit 0a6113b

Kacper-Pietkun, michalkuligowski, and Copilot authored
Add tests for custom operator implementation correctness (#457)
I added tests for custom ops defined in `vllm_gaudi/ops`:

- For tests of ops that do not use CUDA kernels, the native ops and the HPU ops are run on the same input and their outputs are compared.
- For tests of ops that do use CUDA kernels (and therefore cannot be called with the vllm-gaudi plugin), I created a separate directory that stores predefined small tensors: weights, inputs, and outputs. These tensors are too big to hardcode in the tests, but their sizes were adjusted so that all of them weigh less than 3 MB in total. The tensors are stored in the .safetensors format. Such tests run the HPU ops with the loaded inputs and weights and compare their outputs against the loaded reference outputs.

---------

Signed-off-by: Kacper Pietkun <[email protected]>
Signed-off-by: Kacper Pietkun <[email protected]>
Co-authored-by: Michał Kuligowski <[email protected]>
Co-authored-by: Copilot <[email protected]>
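For illustration, a minimal sketch of the first strategy described above (native op and HPU op compared on the same input). It is illustrative only: `native_op` and `hpu_op` are hypothetical stand-ins for a vLLM reference implementation and its vllm_gaudi counterpart, not names used in this commit.

import torch

def check_native_vs_hpu(native_op, hpu_op, atol=1e-3, rtol=1e-3):
    # Generate one deterministic input and feed it to both implementations
    torch.manual_seed(0)
    x = torch.randn(8, 256, dtype=torch.bfloat16)
    ref = native_op(x)                 # reference (native) result on CPU
    out = hpu_op(x.to("hpu")).cpu()    # HPU custom-op result, moved back for comparison
    # The two outputs should match within a small numerical tolerance
    torch.testing.assert_close(ref, out, atol=atol, rtol=rtol)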
1 parent 1668361 commit 0a6113b

19 files changed: +1057 -2 lines changed
Binary files added (test data in .safetensors format; shown sizes: 21.5 KB, 86.3 KB, 1.5 MB, 818 KB, 4.47 KB). Binary contents not displayed.
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import habana_frameworks.torch as htorch
from utils import get_data_path, create_row_parallel_linear
from vllm_gaudi.ops.hpu_awq import AWQHPULinearMethod, AWQHPUConfig
from vllm_gaudi.utils import HPUCompileConfig
from safetensors import safe_open


def test_awq_linear_method(dist_init):
    config = {"bits": 4, "group_size": 128, "zero_point": True}
    oot_quant_config = AWQHPUConfig.from_config(config)

    # Prepare linear layer with oot AWQHPULinearMethod
    oot_op = create_row_parallel_linear(input_size=256, output_size=128, quant_config=oot_quant_config).to("hpu")
    assert isinstance(oot_op.quant_method, AWQHPULinearMethod)

    # qweight, qzeros, scales were extracted from first RowParallelLinear of TheBloke/Llama-2-7B-Chat-AWQ
    # (with adjusted shape, to make tensors smaller)
    with safe_open(get_data_path("data/awq/linear.safetensors"), framework="pt", device="hpu") as f:
        oot_op.qweight.copy_(f.get_tensor("qweight"))
        oot_op.qzeros.copy_(f.get_tensor("qzeros"))
        oot_op.scales.copy_(f.get_tensor("scales"))
    oot_op.quant_method.process_weights_after_loading(oot_op)

    if not htorch.utils.internal.is_lazy():
        compile_config = HPUCompileConfig()
        oot_op = torch.compile(oot_op, **compile_config.get_compile_args())

    # Input and expected output
    # Output tensor holds the data that was returned by cuda implementation of AWQLinearMethod for given input
    # (AWQLinearMethod was triggered offline with the same input as below to get the ref_output)
    with safe_open(get_data_path("data/awq/linear.safetensors"), framework="pt", device="hpu") as f:
        input = f.get_tensor("input").to(torch.bfloat16)
        ref_output = f.get_tensor("ref_output").to(torch.bfloat16)

    # Execute layer
    out = oot_op(input)

    # Check correctness
    torch.testing.assert_close(ref_output, out, atol=1e-3, rtol=1e-3)
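As a side note, here is a hedged sketch (not part of this commit) of how a small data file such as data/awq/linear.safetensors could be produced offline on a CUDA machine: the reference AWQ layer is run once on a fixed input, and the shape-reduced weights, the input, and the reference output are saved under the tensor names the test reads back. `cuda_awq_linear` and the weight tensors are assumed to come from the reference model; they are not defined in this commit.

import torch
from safetensors.torch import save_file

def dump_awq_linear_case(qweight, qzeros, scales, cuda_awq_linear, path="linear.safetensors"):
    # Fixed input so the HPU test can replay exactly the same case
    torch.manual_seed(0)
    x = torch.randn(4, 256, dtype=torch.bfloat16, device="cuda")
    ref_output = cuda_awq_linear(x)  # reference result from the CUDA AWQ kernels
    # Store everything under the names that the HPU test loads with safe_open
    save_file(
        {
            "qweight": qweight.contiguous().cpu(),
            "qzeros": qzeros.contiguous().cpu(),
            "scales": scales.contiguous().cpu(),
            "input": x.cpu(),
            "ref_output": ref_output.detach().cpu(),
        },
        path,
    )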
Lines changed: 234 additions & 0 deletions
@@ -0,0 +1,234 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import habana_frameworks.torch as htorch
from utils import get_data_path, create_row_parallel_linear, create_fused_moe
from unittest.mock import MagicMock
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsConfig
from vllm_gaudi.ops.hpu_compressed_tensors import (HPUCompressedTensorsLinearMethod, HPUCompressedTensorsW8A8Fp8,
                                                   HPUCompressedTensorsWNA16, HPUCompressedTensorsWNA16MoEMethod)
from vllm_gaudi.utils import HPUCompileConfig
from vllm.forward_context import override_forward_context
from safetensors import safe_open


def test_compressed_tensors_linear_method_w8a8fp8(dist_init):
    config = {
        'config_groups': {
            'group_0': {
                'input_activations': {
                    'block_structure': None,
                    'dynamic': True,
                    'group_size': None,
                    'num_bits': 8,
                    'observer': 'memoryless',
                    'observer_kwargs': {},
                    'strategy': 'token',
                    'symmetric': True,
                    'type': 'float'
                },
                'output_activations': None,
                'targets': ['Linear'],
                'weights': {
                    'block_structure': None,
                    'dynamic': False,
                    'group_size': None,
                    'num_bits': 8,
                    'observer': 'minmax',
                    'observer_kwargs': {},
                    'strategy': 'channel',
                    'symmetric': True,
                    'type': 'float'
                }
            }
        },
        'format': 'naive-quantized',
        'global_compression_ratio': 1.239290831149584,
        'ignore': [],
        'kv_cache_scheme': None,
        'quant_method': 'compressed-tensors',
        'quantization_status': 'frozen'
    }
    oot_quant_config = CompressedTensorsConfig.from_config(config)

    # Prepare linear layer with oot CompressedTensorsLinearMethod
    # with HPUCompressedTensorsW8A8Fp8 scheme
    oot_op = create_row_parallel_linear(input_size=256, output_size=256, quant_config=oot_quant_config).to("hpu")
    assert isinstance(oot_op.quant_method, HPUCompressedTensorsLinearMethod)
    assert isinstance(oot_op.scheme, HPUCompressedTensorsW8A8Fp8)

    # Weight and weight_scale were extracted from first RowParallelLinear
    # layer of RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8-dynamic
    # (with adjusted shapes, to make tensors smaller)
    with safe_open(get_data_path("data/compressed_tensors/linear_w8a8fp8.safetensors"), framework="pt",
                   device="hpu") as f:
        oot_op.weight.copy_(f.get_tensor("weight"))
        oot_op.weight_scale.copy_(f.get_tensor("weight_scale"))
    oot_op.quant_method.process_weights_after_loading(oot_op)

    if not htorch.utils.internal.is_lazy():
        compile_config = HPUCompileConfig()
        oot_op = torch.compile(oot_op, **compile_config.get_compile_args())

    # Input and expected output
    # Output tensor holds data that was returned by cuda impl of CompressedTensorsLinearMethod for given input
    # (CompressedTensorsLinearMethod was triggered offline with the same input as below to get the ref_output)
    with safe_open(get_data_path("data/compressed_tensors/linear_w8a8fp8.safetensors"), framework="pt",
                   device="hpu") as f:
        input = f.get_tensor("input")
        ref_output = f.get_tensor("ref_output")

    # Execute layer
    out = oot_op(input)

    # Check correctness
    torch.testing.assert_close(ref_output, out, atol=1e-3, rtol=1e-3)


def test_compressed_tensors_linear_method_wna16(dist_init):
    config = {
        'config_groups': {
            'group_0': {
                'input_activations': None,
                'output_activations': None,
                'targets': ['Linear'],
                'weights': {
                    'actorder': 'weight',
                    'block_structure': None,
                    'dynamic': False,
                    'group_size': 128,
                    'num_bits': 4,
                    'observer': 'minmax',
                    'observer_kwargs': {},
                    'strategy': 'group',
                    'symmetric': False,
                    'type': 'int'
                }
            }
        },
        'format': 'pack-quantized',
        'global_compression_ratio': None,
        'ignore': [],
        'kv_cache_scheme': None,
        'quant_method': 'compressed-tensors',
        'quantization_status': 'compressed'
    }
    oot_quant_config = CompressedTensorsConfig.from_config(config)

    # Prepare linear layer with oot CompressedTensorsLinearMethod
    # with HPUCompressedTensorsWNA16 scheme
    oot_op = create_row_parallel_linear(input_size=256, output_size=256, quant_config=oot_quant_config).to("hpu")
    assert isinstance(oot_op.quant_method, HPUCompressedTensorsLinearMethod)
    assert isinstance(oot_op.scheme, HPUCompressedTensorsWNA16)

    # Weights were extracted from first RowParallelLinear layer of RedHatAI/Qwen3-8B-quantized.w4a16
    # (with adjusted shapes, to make tensors smaller)
    with safe_open(get_data_path("data/compressed_tensors/linear_wna16.safetensors"), framework="pt",
                   device="hpu") as f:
        oot_op.weight_packed.copy_(f.get_tensor("weight_packed"))
        oot_op.weight_scale.copy_(f.get_tensor("weight_scale"))
        oot_op.weight_zero_point.copy_(f.get_tensor("weight_zero_point"))
    oot_op.weight_shape.data = torch.tensor([256, 256], device='hpu:0')
    oot_op.quant_method.process_weights_after_loading(oot_op)

    if not htorch.utils.internal.is_lazy():
        compile_config = HPUCompileConfig()
        oot_op = torch.compile(oot_op, **compile_config.get_compile_args())

    # Input and expected output
    # Output tensor holds data that was returned by cuda impl of CompressedTensorsLinearMethod for given input
    # (CompressedTensorsLinearMethod was triggered offline with the same input as below to get the ref_output)
    with safe_open(get_data_path("data/compressed_tensors/linear_wna16.safetensors"), framework="pt",
                   device="hpu") as f:
        input = f.get_tensor("input")
        ref_output = f.get_tensor("ref_output")

    # Execute layer
    out = oot_op(input)

    # Check correctness
    torch.testing.assert_close(ref_output, out, atol=1e-3, rtol=1e-3)


def test_compressed_tensors_wna16_moe_method(dist_init):
    config = {
        'config_groups': {
            'group_0': {
                'input_activations': None,
                'output_activations': None,
                'targets': ['Linear'],
                'weights': {
                    'actorder': 'weight',
                    'block_structure': None,
                    'dynamic': False,
                    'group_size': 128,
                    'num_bits': 4,
                    'observer': 'minmax',
                    'observer_kwargs': {},
                    'strategy': 'group',
                    'symmetric': True,
                    'type': 'int'
                }
            }
        },
        'format': 'pack-quantized',
        'global_compression_ratio': None,
        'ignore': [],
        'kv_cache_scheme': None,
        'quant_method': 'compressed-tensors',
        'quantization_status': 'compressed'
    }
    oot_quant_config = CompressedTensorsConfig.from_config(config)

    # Prepare FusedMoE layer with oot HPUCompressedTensorsWNA16MoEMethod
    oot_op = create_fused_moe(oot_quant_config).to("hpu")
    assert isinstance(oot_op.quant_method, HPUCompressedTensorsWNA16MoEMethod)

    # Weights were extracted from first FusedMoE layer of RedHatAI/Qwen3-30B-A3B-quantized.w4a16
    # (with adjusted shapes, to make tensors smaller)
    with safe_open(get_data_path("data/compressed_tensors/moe_wna16.safetensors"), framework="pt", device="hpu") as f:
        w2_weight_packed = f.get_tensor("w2_weight_packed")
        w2_weight_packed = torch.swapaxes(w2_weight_packed, 0, 1).repeat(128, 1, 1)
        oot_op.w2_weight_packed.copy_(w2_weight_packed)

        w13_weight_packed = f.get_tensor("w13_weight_packed")
        w13_weight_packed = torch.swapaxes(w13_weight_packed, 0, 1).repeat(128, 1, 1)
        oot_op.w13_weight_packed.copy_(w13_weight_packed)

        w2_weight_scale = f.get_tensor("w2_weight_scale")
        w2_weight_scale = torch.swapaxes(w2_weight_scale, 0, 1).repeat(128, 1, 1)
        oot_op.w2_weight_scale.copy_(w2_weight_scale)

        w13_weight_scale = f.get_tensor("w13_weight_scale")
        w13_weight_scale = torch.swapaxes(w13_weight_scale, 0, 1).repeat(128, 1, 1)
        oot_op.w13_weight_scale.copy_(w13_weight_scale)

    w2_weight_shape = torch.tensor([512, 256], dtype=torch.bfloat16, device="hpu")
    oot_op.w2_weight_shape.copy_(w2_weight_shape.repeat(128, 1))

    w13_weight_shape = torch.tensor([256, 512], dtype=torch.bfloat16, device="hpu")
    oot_op.w13_weight_shape.copy_(w13_weight_shape.repeat(128, 1))

    oot_op.quant_method.process_weights_after_loading(oot_op)

    if not htorch.utils.internal.is_lazy():
        compile_config = HPUCompileConfig()
        oot_op = torch.compile(oot_op, **compile_config.get_compile_args())

    # Input and expected output
    # Output tensor holds data that was returned by cuda impl of CompressedTensorsWNA16MarlinMoEMethod for given input
    # (CompressedTensorsWNA16MarlinMoEMethod was triggered offline with the same input as below to get the ref_output)
    with safe_open(get_data_path("data/compressed_tensors/moe_wna16.safetensors"), framework="pt", device="hpu") as f:
        hidden_states = f.get_tensor("hidden_states")
        router_logits = f.get_tensor("router_logits")
        ref_output = f.get_tensor("ref_output")

    # Execute layer
    mock_ctx = MagicMock(spec=["dp_metadata"])
    mock_ctx.dp_metadata = None
    with override_forward_context(mock_ctx):
        out = oot_op.forward_impl(hidden_states, router_logits)

    # Check correctness
    torch.testing.assert_close(ref_output, out, atol=1e-4, rtol=1e-4)
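For context, a hedged guess (an assumption, not code from this commit) at what the create_row_parallel_linear helper imported from the tests' utils module might look like: it would build a standard vLLM RowParallelLinear whose quant_method is selected by the supplied quant_config, which is what the tests above assert on. The exact signature and defaults of the real helper are not shown in this diff.

import torch
from vllm.model_executor.layers.linear import RowParallelLinear

def create_row_parallel_linear(input_size, output_size, quant_config=None):
    # The quant_config decides which quant_method the layer gets
    # (e.g. AWQHPULinearMethod or HPUCompressedTensorsLinearMethod with the plugin loaded).
    return RowParallelLinear(
        input_size=input_size,
        output_size=output_size,
        bias=False,                   # keep the layer minimal for a unit test
        params_dtype=torch.bfloat16,  # example dtype; the real helper's choice is not shown in the diff
        quant_config=quant_config,
        return_bias=False,            # assumed, so forward returns a single tensor as the tests expect
    )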
