Commit 03e694a
[AWQ] Generalize AWQ quantization (#1961)
### Summary
To allow arbitrary heterogeneous quantization schemes, this PR switches several helpers from AutoAWQ code to the shared observer and QDQ (quantize-dequantize) logic. AWQ no longer requires every config group in the quantization config to use the same `group_size`, `symmetric`, and `num_bits` settings.
Resolves #1657
Prerequisites:
* vllm-project/compressed-tensors#519
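As an illustration of the added flexibility, a heterogeneous recipe along these lines should now be expressible. This is a hypothetical sketch, not a config from this PR's tests: the `config_groups` layout follows the compressed-tensors `QuantizationScheme`/`QuantizationArgs` API, and the specific targets and bit widths are made up.
```python
# Hypothetical sketch: an AWQ recipe whose config groups differ in num_bits,
# symmetric, and group_size -- settings that previously had to match across
# groups. Targets and settings here are illustrative only.
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from llmcompressor.modifiers.awq import AWQModifier

recipe = AWQModifier(
    ignore=["lm_head"],
    config_groups={
        # 4-bit asymmetric, group size 128, for attention projections
        "group_0": QuantizationScheme(
            targets=["re:.*self_attn\\..*_proj$"],
            weights=QuantizationArgs(
                num_bits=4, symmetric=False, strategy="group", group_size=128
            ),
        ),
        # 8-bit symmetric, per-channel, for MLP projections
        "group_1": QuantizationScheme(
            targets=["re:.*mlp\\..*_proj$"],
            weights=QuantizationArgs(num_bits=8, symmetric=True, strategy="channel"),
        ),
    },
)
```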
### Test plan
- [x] When running `llm-compressor/examples/awq/llama_example.py` on this branch with `duo_scaling="both"` and logging the best `(ratio, duo_scaling)` configuration, I see a good mix of `True` and `False`: a good percentage of `best_scales` were found with `duo_scaling=False` and a good percentage with `duo_scaling=True` (a simplified sketch of this search is included after the result tables below). Generated model output looks good.
- [x] When using `awq_oneshot.py` (pasted below), Wikitext PPL on this branch is consistent with main for w4a16 and w4a16_asym, and better than what was reported in a [previous AWQ PR](#1444 (comment)), though those runs may have been configured differently. For W4A16_ASYM, the result is 13.41 on both main and this branch; this is the configuration we have historically used to test for regressions.

| Scheme | Wikitext PPL (RTN) | AWQ (main) | AWQ (this branch) |
|------------|-------------------:|-----------:|------------------:|
| W4A16      | 13.784 | 13.477 | 13.426 |
| W4A16_ASYM | 13.606 | 13.346 | 13.377 |
- [x] I see a small regression in recovery when running `CADENCE=weekly
TEST_DATA_FILE=~/projects/llm-compressor/tests/lmeval/configs/w4a16_awq_sym.yaml
pytest -s ~/projects/llm-compressor/tests/lmeval/test_lmeval.py` on this
branch, which causes the test to fail. This persists even when using
`pseudo_quantize_tensor` instead of `call_observer`/`forward_quantize`,
as shown in [this
diff](https://github.com/vllm-project/llm-compressor/compare/kylesayrs/awq-generalize-quant...bdellabe/awq-generalize-quant?expand=1).
I get the same result with that diff, so at least the quantization logic in CT (compressed-tensors) is consistent with AutoAWQ's.
Output:
```
<main>
2025-11-17T18:26:04.682699+0000 | _validate_recovery | INFO - ✓ exact_match,strict-match | Base: 0.7650 | Compressed: 0.7090 | Recovery: 92.68% ↑ | Threshold: ≥92.00%
2025-11-17T18:26:04.682811+0000 | _validate_recovery | INFO - ✓ exact_match,flexible-extract | Base: 0.7630 | Compressed: 0.7100 | Recovery: 93.05% ↑ | Threshold: ≥93.00%
<this branch>
2025-11-17T17:55:00.648672+0000 | _validate_recovery | ERROR - ✗ exact_match,strict-match | Base: 0.7650 | Compressed: 0.6950 | Recovery: 90.85% ↑ | Threshold: ≥92.00%
2025-11-17T17:55:00.648967+0000 | _validate_recovery | ERROR - ✗ exact_match,flexible-extract | Base: 0.7630 | Compressed: 0.6960 | Recovery: 91.22% ↑ | Threshold: ≥93.00%
```
This is already a pretty large drop in recovery; should we revisit this test?
- [x] Further regression testing against main was done in this [commit](8b6b0a5); see [run.sh](https://github.com/vllm-project/llm-compressor/blob/8b6b0a5ae27084756df5d7e3fd0eca60cbe07b87/run.sh) as of that commit (the script was removed in the final PR). Results comparing this branch against main look reasonable: some metrics are up, some down, all within the margin of error.
Test Group Quantization (w4a16_awq_sym)

| Branch    | Metric                       | Base   | Compressed | Recovery |
|-----------|------------------------------|--------|------------|----------|
| On Branch | exact_match,strict-match     | 0.7620 | 0.7170     | 94.09% ↑ |
| On Branch | exact_match,flexible-extract | 0.7600 | 0.7130     | 93.82% ↑ |
| On Main   | exact_match,strict-match     | 0.7620 | 0.7090     | 93.04%   |
| On Main   | exact_match,flexible-extract | 0.7600 | 0.7060     | 92.89%   |

Test Tensor Quantization (int8_tensor)

| Branch    | Metric                       | Base   | Compressed | Recovery |
|-----------|------------------------------|--------|------------|----------|
| On Branch | exact_match,strict-match     | 0.7620 | 0.7220     | 94.75% ↓ |
| On Branch | exact_match,flexible-extract | 0.7600 | 0.7240     | 95.26% ↓ |
| On Main   | exact_match,strict-match     | 0.7620 | 0.7280     | 95.54%   |
| On Main   | exact_match,flexible-extract | 0.7600 | 0.7310     | 96.18%   |

Test Channel Quantization (fp8_dynamic)

| Branch    | Metric                       | Base   | Compressed | Recovery |
|-----------|------------------------------|--------|------------|----------|
| On Branch | exact_match,strict-match     | 0.7650 | 0.7610     | 99.48%   |
| On Branch | exact_match,flexible-extract | 0.7630 | 0.7580     | 99.34%   |

Test Block Quantization (fp8_block)

| Branch    | Metric                       | Base   | Compressed | Recovery |
|-----------|------------------------------|--------|------------|----------|
| On Branch | exact_match,strict-match     | 0.7650 | 0.7720     | 100.92%  |
| On Branch | exact_match,flexible-extract | 0.7630 | 0.7690     | 100.79%  |
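For context on the `duo_scaling="both"` check in the first test-plan item, here is a rough, self-contained sketch of that grid search. This is not the AWQModifier implementation: `fake_quantize` is a toy stand-in for the observer/QDQ path, the loss is a simplified output MSE rather than the real module-output comparison, and it assumes `"both"` simply tries both scale formulas at every ratio and keeps whichever gives the lowest error.
```python
# Simplified sketch of the (ratio, duo_scaling) grid search -- NOT the actual
# AWQModifier code. It only illustrates how duo_scaling="both" can pick either
# scale formula, depending on which reconstructs the output best.
import torch


def fake_quantize(w: torch.Tensor, num_bits: int = 4, group_size: int = 128) -> torch.Tensor:
    """Toy symmetric group-wise quantize-dequantize (stand-in for QDQ logic)."""
    shape = w.shape
    grouped = w.reshape(-1, group_size)
    qmax = 2 ** (num_bits - 1) - 1
    scale = grouped.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / qmax
    dq = (grouped / scale).round().clamp(-qmax - 1, qmax) * scale
    return dq.reshape(shape)


def best_awq_scales(w: torch.Tensor, x: torch.Tensor, n_grid: int = 20):
    """Search ratios and both duo_scaling formulas; keep the lowest-loss scales."""
    x_mean = x.abs().mean(dim=0)  # per-input-channel activation magnitude
    w_mean = w.abs().mean(dim=0)  # per-input-channel weight magnitude
    fp_out = x @ w.T              # unquantized reference output
    best_loss, best = float("inf"), None

    for i in range(n_grid):
        ratio = i / n_grid
        for duo in (False, True):  # duo_scaling="both": try both formulas
            if duo:
                scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4)
            else:
                scales = x_mean.pow(ratio).clamp(min=1e-4)
            scales = scales / (scales.max() * scales.min()).sqrt()

            # scale weights, fake-quantize, undo the scale, measure output error
            w_dq = fake_quantize(w * scales) / scales
            loss = (fp_out - x @ w_dq.T).pow(2).mean().item()
            if loss < best_loss:
                best_loss, best = loss, (scales, ratio, duo)
    return best, best_loss


if __name__ == "__main__":
    torch.manual_seed(0)
    (scales, ratio, duo), loss = best_awq_scales(torch.randn(256, 512), torch.randn(64, 512))
    print(f"best ratio={ratio:.2f}, duo_scaling={duo}, loss={loss:.6f}")
```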
<details>
<summary>awq_oneshot.py script</summary>

```python
import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

from llmcompressor import oneshot, active_session
from llmcompressor.utils import dispatch_for_generation
from llmcompressor.modifiers.awq import AWQModifier, AWQMapping
from llmcompressor.modifiers.quantization import QuantizationModifier
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    QuantizationType,
)

MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
SAVE_DIR = MODEL_ID.split("/")[-1] + "-awq-asym"

# Configure the quantization algorithm to run.
recipe = [
    AWQModifier(
        ignore=[
            "lm_head",
            "re:.*mlp.gate$",
            "re:.*mlp.shared_expert_gate$",
            "re:visual.*",
        ],
        scheme="W4A16_ASYM",
        duo_scaling="both",
        targets=["Linear"],
        # offload_device=torch.device("cpu"),
    ),
]

# Select calibration dataset.
DATASET_ID = "mit-han-lab/pile-val-backup"
DATASET_SPLIT = "validation"

# Select number of samples. 256 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 512


def get_calib_dataset(tokenizer):
    ds = load_dataset(
        DATASET_ID,
        split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES * 10}]",
    )

    def preprocess(example):
        return {"input_ids": tokenizer.encode(example["text"].strip())}

    ds = (
        ds.shuffle(seed=42)
        .map(preprocess, remove_columns=ds.column_names)
        .select(range(NUM_CALIBRATION_SAMPLES))
    )
    return ds


if __name__ == "__main__":
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype="auto", trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    ###
    ### Apply algorithms.
    ###
    oneshot(
        model=model,
        dataset=get_calib_dataset(tokenizer),
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
        log_dir=None,
        trust_remote_code_model=True,
    )

    # Confirm generations of the quantized model look sane.
    dispatch_for_generation(model)
    print("\n\n")
    print("========== SAMPLE GENERATION ==============")
    input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
    output = model.generate(input_ids, max_new_tokens=100)
    print(tokenizer.decode(output[0]))
    print("==========================================\n\n")

    # Save to disk compressed.
    model.save_pretrained(SAVE_DIR)
    tokenizer.save_pretrained(SAVE_DIR)

    ###
    ### LM Eval on Wikitext.
    ###
    active_session().reset()
    del model
    del tokenizer
    torch.cuda.empty_cache()

    import lm_eval
    from lm_eval.utils import make_table

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args={
            "pretrained": SAVE_DIR,
            "add_bos_token": True,
            "dtype": "bfloat16",
            "gpu_memory_utilization": 0.7,
            "max_model_len": 4096,
            # "max_num_batched_tokens": 128,
            # "max_num_seqs": 128,
        },
        tasks=["wikitext"],
        batch_size=128,
    )
    print(make_table(results))
```
</details>
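A small follow-up note on the script: `make_table(results)` prints the formatted Wikitext numbers used in the PPL table above. If you want the raw perplexity programmatically, it can be read out of the `simple_evaluate` results dict; a minimal sketch, assuming the `"<metric>,<filter>"` key style used by recent lm-eval versions (the exact key may differ across versions):
```python
# Sketch: pull the Wikitext word perplexity out of the lm_eval results dict.
# Assumes lm-eval >= 0.4 style metric keys ("word_perplexity,none"); adjust the
# key if your version reports metrics differently.
def wikitext_word_ppl(results: dict) -> float:
    return results["results"]["wikitext"]["word_perplexity,none"]

# e.g., at the end of the script above:
# print(f"Wikitext word PPL: {wikitext_word_ppl(results):.3f}")
```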
---------
Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Brian Dellabetta <[email protected]>
Signed-off-by: HDCharles <[email protected]>
Co-authored-by: Brian Dellabetta <[email protected]>
Co-authored-by: HDCharles <[email protected]>
Co-authored-by: Fynn Schmitt-Ulms <[email protected]>
File tree (4 files changed, +448 -275 lines):
- examples/awq
- src/llmcompressor/modifiers/awq
- src/llmcompressor/modifiers/quantization
- tests/llmcompressor/modifiers/awq