Commit 056ed3d

kylesayrs and dsikka authored
[Performance] Batched calibration (#2054)
## Purpose ##
* Reduce calibration runtime by providing users with options to increase performance
  * `batch_size` controls the batch size of the calibration data
  * `offload_sequential_activations` controls whether intermediate calibration activations are offloaded to the CPU between layers

## Prerequisites ##
* #2080
* #2081

## Changes ##

### Batched Calibration ###
* Add a `batch_size` argument
* Change the `data_collator` default from the default data collator to a `"truncation"` collator
  * The `data_collator_with_truncation` function truncates all samples in a batch to the length of the shortest sample (see the illustrative sketch after this description)
  * Statistics on how many tokens are dropped by this method are in the tables below
  * The data collator can instead be set to `"padding"` to pad to the longest sample in the batch
* To reduce excess truncation/padding, default to the `LengthAwareSampler`, which samples the dataset so that samples of similar length are batched together

Batch Size | Time | % Speedup | % Deleted
-- | -- | -- | --
Original (1) | 11m17 | N/A | 0.0
1 | 11m17 | 0.0 | 0.0
2 | 10m48 | 4.2 | 0.2
4 | 10m39 | 5.6 | 0.5
8 | 10m39 | 5.6 | 1.1
16 | 10m58 | 2.8 | 2.6
64 | 11m4 | 11.2 | 12.0
128 | 9m29 | 16.0 | 23.9
512 | 7m39 | 37.3 | 75.3

* The speedup is relatively meager until you start deleting significant portions of the dataset via truncation

### Disable Offloading ###
* Add an `offload_sequential_activations` argument, which defaults to True (no behavior change)
* Disabling offloading increases throughput but also increases memory usage

Batch Size | Time | % Speedup | % Deleted
-- | -- | -- | --
Original (1) | 11m17 | N/A | 0.0
1 | 10m14 | 9.3 | 0.0
2 | 9m46 | 13.4 | 0.2
4 | 9m36 | 14.9 | 0.5
8 | 9m48 | 13.1 | 1.1
16 | 9m26 | 16.3 | 2.6
32 | 9m27 | 16.2 | 5.8
128 | 8m34 | 24.0 | 23.9
512 | 6m40 | 40.9 | 75.3

* The memory requirement for 512 samples on Llama 8B is ~70 GB, which is equivalent to that of batch size 128
* With offloading disabled and batch size 32, calibration runtime is less than 1s per layer (down from ~11s)
* This implies that the theoretical maximum speedup from reducing calibration time alone is ~15% for this model and dataset

### Misc ###
* Fix examples
  * Fixed examples with dtype mismatches between the model and processor outputs (Mistral3, Pixtral, Whisper)
  * For multimodal models which use multimodal datasets, remove their data collators, as the batch unwrapping is now done by the `TextGenerationDataset`
* Remove `_mask_padding` from `IntermediatesCache`, as I do not believe this method is effective at masking padding tokens from Hessian calculations
* Fix AWQ
  * AWQ was hard-coded to handle only batches of size 1

## Testing ##

### Evaluation Regression ###

Batch Size | Eval Score | Difference | % Deleted
-- | -- | -- | --
Original (1) | 0.6573 | 0.000 | 0.0
1 | 0.6513 | -0.6 | 0.0
2 | 0.6513 | -0.6 | 0.2
4 | 0.6657 | +0.8 | 0.5
8 | 0.6513 | -0.6 | 1.1
16 | 0.6672 | +1.0 | 2.6
64 | 0.6338 | -2.4 | 12.0
128 | 0.6603 | +0.3 | 23.9
512 | 0.6391 | -1.8 | 75.3

Deleting significant portions of the dataset (longer sequences are deleted first) has a detrimental effect on recovery.

### Modifiers ###
* GPTQ
  * Ran the full regression tests shown above
* AWQ
  * Ran AWQ with batch size 32 and checked output sanity
* Quantization Modifier
  * Ran NVFP4 with batch size 10 and checked output sanity

### Calibration Regression Testing ###
I ran calibration for the following models (but did not evaluate recovery).

The following model examples calibrate without issue:
* Llama3
* Gemma3
* Internvl3
* Mllama
* Llama4

The following models had a bug where processor and model dtypes were mismatched, which is now fixed by this PR:
* Mistral3
* Pixtral
* Whisper

The following models have an accelerate device offloading bug:
* Idefics3
* Phi3 Vision

The following model examples have an MoE replacement bug:
* qwen3-vl-30b-a3b-Instruct

## Future Work ##
While these options are a great place to start, the next step for improving runtime is to allow multi-GPU compression, likely via torch.distributed tensor parallelism.

---------

Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
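To make the truncation and length-aware batching described above concrete, here is a minimal illustrative sketch. This is not the PR's actual `data_collator_with_truncation` or `LengthAwareSampler` implementation, and the function names are hypothetical; it only shows the two ideas at play: truncating each batch to its shortest sample, and grouping samples of similar length so that truncation (or padding) touches fewer tokens.

```python
import torch


def truncation_collator(features: list[dict]) -> dict[str, torch.Tensor]:
    # Truncate every sample in the batch to the length of the shortest sample,
    # then stack into batched tensors. Assumes every feature is a per-token
    # list (e.g. input_ids, attention_mask); hypothetical helper, not the
    # library's implementation.
    min_len = min(len(feature["input_ids"]) for feature in features)
    return {
        key: torch.tensor([feature[key][:min_len] for feature in features])
        for key in features[0]
    }


def length_aware_batches(lengths: list[int], batch_size: int) -> list[list[int]]:
    # Group sample indices so that similar-length samples share a batch,
    # which minimizes how many tokens truncation or padding affects.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    return [order[i : i + batch_size] for i in range(0, len(order), batch_size)]
```

Under this reading, the `% Deleted` columns above correspond to the tokens removed by the `[:min_len]` slice across the calibration set.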
1 parent 6cf8d29 commit 056ed3d

17 files changed: +259 additions, -203 deletions

examples/multimodal_audio/whisper_example.py

Lines changed: 17 additions & 6 deletions
```diff
@@ -1,6 +1,10 @@
 import torch
 from datasets import load_dataset
-from transformers import WhisperForConditionalGeneration, WhisperProcessor
+from transformers import (
+    WhisperForConditionalGeneration,
+    WhisperProcessor,
+    default_data_collator,
+)
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
@@ -55,20 +59,27 @@ def process(sample):
         return_tensors="pt",
     )
 
-    inputs["input_features"] = inputs["input_features"].to(dtype=model.dtype)
+    # treat labels as calibration prefill
     inputs["decoder_input_ids"] = inputs["labels"]
     del inputs["labels"]
 
+    # strip extra dim added by multimodal processors
+    inputs = {key: value[0] for key, value in inputs.items()}
+
     return inputs
 
 
 ds = ds.map(process, remove_columns=ds.column_names)
 
 
-# Define a oneshot data collator for multimodal inputs.
-def data_collator(batch):
-    assert len(batch) == 1
-    return {key: torch.tensor(value) for key, value in batch[0].items()}
+# Patch: mismatch between processor and model dtype
+def data_collator(features):
+    for feature in features:
+        feature["input_features"] = torch.tensor(
+            feature["input_features"], dtype=model.dtype
+        )
+
+    return default_data_collator(features, return_tensors="pt")
 
 
 # Recipe
```

examples/multimodal_vision/gemma3_example.py

Lines changed: 5 additions & 13 deletions
```diff
@@ -1,5 +1,4 @@
 import requests
-import torch
 from PIL import Image
 from transformers import AutoProcessor, Gemma3ForConditionalGeneration
 
@@ -13,17 +12,11 @@
 processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
 
 # Oneshot arguments
-DATASET_ID = "flickr30k"
-DATASET_SPLIT = {"calibration": "test[:512]"}
+BATCH_SIZE = 4
 NUM_CALIBRATION_SAMPLES = 512
 MAX_SEQUENCE_LENGTH = 2048
-
-
-# Define a oneshot data collator for multimodal inputs.
-def data_collator(batch):
-    assert len(batch) == 1
-    return {key: torch.tensor(value) for key, value in batch[0].items()}
-
+DATASET_ID = "flickr30k"
+DATASET_SPLIT = {"calibration": f"test[:{NUM_CALIBRATION_SAMPLES}]"}
 
 # Recipe
 recipe = [
@@ -41,14 +34,13 @@ def data_collator(batch):
 # Perform oneshot
 oneshot(
     model=model,
-    tokenizer=model_id,
+    processor=processor,
     dataset=DATASET_ID,
     splits=DATASET_SPLIT,
     recipe=recipe,
+    batch_size=BATCH_SIZE,
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
-    trust_remote_code_model=True,
-    data_collator=data_collator,
 )
 
 # Confirm generations of the quantized model look sane.
```

examples/multimodal_vision/internvl3_example.py

Lines changed: 4 additions & 11 deletions
```diff
@@ -37,20 +37,14 @@ def preprocess_and_tokenize(example):
         return_dict=True,
         return_tensors="pt",
     )
-    return inputs
-
 
-ds = ds.map(preprocess_and_tokenize)
+    # remove extra dim added by multimodal processors
+    inputs = {key: value[0] for key, value in inputs.items()}
 
+    return inputs
 
-def data_collator(batch):
-    assert len(batch) == 1
-    item = {key: value for key, value in batch[0].items()}
-    item["attention_mask"] = torch.tensor([item["attention_mask"]])
-    item["input_ids"] = torch.LongTensor([item["input_ids"]])
-
-    return item
 
+ds = ds.map(preprocess_and_tokenize, remove_columns=ds.column_names)
 
 # Recipe
 recipe = GPTQModifier(
@@ -68,7 +62,6 @@ def data_collator(batch):
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
-    data_collator=data_collator,
 )
 
 # Save to disk compressed.
```

examples/multimodal_vision/llava_example.py

Lines changed: 0 additions & 8 deletions
```diff
@@ -1,5 +1,4 @@
 import requests
-import torch
 from PIL import Image
 from transformers import AutoProcessor, LlavaForConditionalGeneration
 
@@ -19,12 +18,6 @@
 MAX_SEQUENCE_LENGTH = 2048
 
 
-# Define a oneshot data collator for multimodal inputs.
-def data_collator(batch):
-    assert len(batch) == 1
-    return {key: torch.tensor(value) for key, value in batch[0].items()}
-
-
 # Recipe
 recipe = [
     GPTQModifier(
@@ -44,7 +37,6 @@ def data_collator(batch):
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
-    data_collator=data_collator,
     sequential_targets=["LlamaDecoderLayer"],
 )
```

examples/multimodal_vision/mistral3_example.py

Lines changed: 11 additions & 11 deletions
```diff
@@ -4,7 +4,11 @@
 import requests
 import torch
 from PIL import Image
-from transformers import AutoProcessor, Mistral3ForConditionalGeneration
+from transformers import (
+    AutoProcessor,
+    Mistral3ForConditionalGeneration,
+    default_data_collator,
+)
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
@@ -27,17 +31,13 @@
 MAX_SEQUENCE_LENGTH = 2048
 
 
-# Define a oneshot data collator for multimodal inputs.
-def data_collator(batch):
-    assert len(batch) == 1
-    return {
-        key: (
-            torch.tensor(value)
-            if key != "pixel_values"
-            else torch.tensor(value, dtype=model.dtype)
+# Patch: mismatch between processor and model dtype
+def data_collator(features):
+    for feature in features:
+        feature["pixel_values"] = torch.tensor(
+            feature["pixel_values"], dtype=model.dtype
         )
-        for key, value in batch[0].items()
-    }
+    return default_data_collator(features, return_tensors="pt")
 
 
 # Recipe
```

examples/multimodal_vision/mllama_example.py

Lines changed: 0 additions & 8 deletions
```diff
@@ -1,5 +1,4 @@
 import requests
-import torch
 from PIL import Image
 from transformers import AutoProcessor, MllamaForConditionalGeneration
 
@@ -19,12 +18,6 @@
 MAX_SEQUENCE_LENGTH = 2048
 
 
-# Define a oneshot data collator for multimodal inputs.
-def data_collator(batch):
-    assert len(batch) == 1
-    return {key: torch.tensor(value) for key, value in batch[0].items()}
-
-
 # Recipe
 recipe = [
     GPTQModifier(
@@ -44,7 +37,6 @@ def data_collator(batch):
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
-    data_collator=data_collator,
     sequential_targets=["MllamaSelfAttentionDecoderLayer"],
 )
```

examples/multimodal_vision/pixtral_example.py

Lines changed: 13 additions & 12 deletions
```diff
@@ -1,7 +1,11 @@
 import requests
 import torch
 from PIL import Image
-from transformers import AutoProcessor, LlavaForConditionalGeneration
+from transformers import (
+    AutoProcessor,
+    LlavaForConditionalGeneration,
+    default_data_collator,
+)
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
@@ -19,16 +23,13 @@
 MAX_SEQUENCE_LENGTH = 2048
 
 
-# Define a oneshot data collator for multimodal inputs.
-# NOTE: for transformers<4.48.0, please squeeze the first dimension of `pixel_values`
-# by appending `[0]` to the end of line 32
-def data_collator(batch):
-    assert len(batch) == 1
-    return {
-        "input_ids": torch.LongTensor(batch[0]["input_ids"]),
-        "attention_mask": torch.tensor(batch[0]["attention_mask"]),
-        "pixel_values": torch.tensor(batch[0]["pixel_values"]),
-    }
+# Patch: mismatch between processor and model dtype
+def data_collator(features):
+    for feature in features:
+        feature["pixel_values"] = torch.tensor(
+            feature["pixel_values"], dtype=model.dtype
+        )
+    return default_data_collator(features, return_tensors="pt")
 
 
 # Recipe
@@ -46,11 +47,11 @@ def data_collator(batch):
     tokenizer=model_id,
     dataset=DATASET_ID,
     splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
+    data_collator=data_collator,
     recipe=recipe,
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
-    data_collator=data_collator,
     sequential_targets=["MistralDecoderLayer"],
 )
```

src/llmcompressor/args/dataset_arguments.py

Lines changed: 33 additions & 9 deletions
```diff
@@ -8,9 +8,7 @@
 """
 
 from dataclasses import dataclass, field
-from typing import Any, Callable
-
-from transformers import DefaultDataCollator
+from typing import Callable
 
 
 @dataclass
@@ -69,9 +67,27 @@ class CustomDatasetArguments(DVCDatasetArguments):
         },
     )
 
-    data_collator: Callable[[Any], Any] = field(
-        default_factory=lambda: DefaultDataCollator(),
-        metadata={"help": "The function to used to form a batch from the dataset"},
+    batch_size: int = field(
+        default=1,
+        metadata={
+            "help": (
+                "Calibration batch size. During calibration, LLM Compressor disables "
+                "lm_head output computations to reduce memory usage from large "
+                "batch sizes. Large batch sizes may result in excess padding or "
+                "truncation, depending on the data_collator"
+            )
+        },
+    )
+
+    data_collator: str | Callable = field(
+        default="truncation",
+        metadata={
+            "help": (
+                "The function to used to form a batch from the dataset. Can also "
+                "specify 'truncation' or 'padding' to truncate or pad non-uniform "
+                "sequence lengths in a batch. Defaults to 'padding'."
+            )
+        },
     )
 
 
@@ -126,8 +142,8 @@ class DatasetArguments(CustomDatasetArguments):
         default=512,
         metadata={"help": "Number of samples to use for one-shot calibration"},
     )
-    shuffle_calibration_samples: bool | None = field(
-        default=True,
+    shuffle_calibration_samples: bool = field(
+        default=False,
         metadata={
            "help": "whether to shuffle the dataset before selecting calibration data"
         },
@@ -142,7 +158,7 @@ class DatasetArguments(CustomDatasetArguments):
     )
     preprocessing_num_workers: int | None = field(
         default=None,
-        metadata={"help": "The number of processes to use for the preprocessing."},
+        metadata={"help": "The number of workers to use for dataset processing."},
     )
     pad_to_max_length: bool = field(
         default=True,
@@ -214,6 +230,14 @@ class DatasetArguments(CustomDatasetArguments):
             "definition"
         },
     )
+    offload_sequential_activations: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to offload intermediate activations between sequential "
+            "layers to the CPU. Disabling offloading is much faster, but uses "
+            "signficiantly more memory. Default is True."
+        },
+    )
     quantization_aware_calibration: bool = field(
         default=True,
         metadata={
```
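For reference, here is a minimal sketch of how the new dataset arguments defined above might be passed through `oneshot` (assuming they are forwarded like the other dataset arguments, as `batch_size` is in the updated gemma3 example). The model, dataset, and recipe below are placeholders chosen for illustration, not part of this PR.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

# Placeholder model and recipe; any of the example models works the same way.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

oneshot(
    model=model,
    tokenizer=tokenizer,
    dataset="open_platypus",
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
    batch_size=32,                         # new: calibration batch size
    data_collator="truncation",            # new default; "padding" is also accepted
    offload_sequential_activations=False,  # new: faster, but uses more memory
)
```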
