Skip to content

Commit 6da9242

Browse files
committed
Bug fix for ptq_generate, add functional test for ptq and ptq_generate
Signed-off-by: James Shen <[email protected]>
1 parent df38da1 commit 6da9242

File tree

4 files changed

+98
-19
lines changed

4 files changed

+98
-19
lines changed

examples/quantization/ptq_generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def main(
9292
model_provider.initialize_model_parallel(seed=0)
9393
megatron_model = bridge.load_megatron_model(
9494
megatron_load_path,
95-
mp_override={
95+
mp_overrides={
9696
"tensor_model_parallel_size": tp,
9797
"pipeline_model_parallel_size": pp,
9898
"expert_model_parallel_size": ep,

tests/functional_tests/L2_Launch_quantization.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
2222

2323
uv run coverage run --data-file=/opt/Megatron-Bridge/.coverage --source=/opt/Megatron-Bridge/ --parallel-mode -m pytest \
2424
-o log_cli=true -o log_cli_level=INFO -v -s -x -m "not pleasefixme" --tb=short -rA \
25-
tests/functional_tests/quantization/test_qat_workflow.py
25+
tests/functional_tests/quantization/test_qat_workflow.py \
26+
tests/functional_tests/quantization/test_quantization_workflow.py
2627
coverage combine -q
2728

tests/functional_tests/quantization/test_qat_workflow.py

Lines changed: 81 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,18 @@
1414

1515
"""Functional tests for QAT (Quantization Aware Training) workflow."""
1616

17+
import os
1718
import subprocess
1819
from pathlib import Path
1920

2021
import pytest
2122

23+
from megatron.bridge.training.utils.checkpoint_utils import (
24+
TRACKER_PREFIX,
25+
get_checkpoint_name,
26+
get_checkpoint_tracker_filename,
27+
get_checkpoint_train_state_filename,
28+
)
2229
from tests.functional_tests.utils import clear_directories
2330

2431

@@ -98,6 +105,8 @@ def _run_pretrain_from_quantized_checkpoint(
98105
tp: int = 1,
99106
pp: int = 1,
100107
cp: int = 2,
108+
train_iters: int = 10,
109+
save_interval: int = 10,
101110
):
102111
"""
103112
Run pre-training from a quantized checkpoint using subprocess.
@@ -109,9 +118,12 @@ def _run_pretrain_from_quantized_checkpoint(
109118
tp: Tensor parallelism size
110119
pp: Pipeline parallelism size
111120
cp: Context parallelism size (default: 2)
121+
train_iters: Number of training iterations
122+
save_interval: Interval for saving checkpoints
112123
113124
Returns:
114-
subprocess.CompletedProcess: The result of the subprocess run
125+
tuple: (subprocess.CompletedProcess, final_iteration)
126+
where final_iteration is the last checkpoint saved
115127
"""
116128
# Calculate total number of processes needed (tp * pp * cp)
117129
total_procs = tp * pp * cp
@@ -121,6 +133,10 @@ def _run_pretrain_from_quantized_checkpoint(
121133

122134
python_executable = sys.executable
123135

136+
# Calculate the final iteration (last checkpoint that will be saved)
137+
# Checkpoints are saved at intervals, so the last one is at train_iters if it's a multiple of save_interval
138+
final_iteration = (train_iters // save_interval) * save_interval
139+
124140
# Base command for pre-training from quantized checkpoint
125141
cmd = [
126142
python_executable,
@@ -139,13 +155,13 @@ def _run_pretrain_from_quantized_checkpoint(
139155
"model.gradient_accumulation_fusion=False",
140156
f"checkpoint.pretrained_checkpoint={quantized_checkpoint_path}",
141157
f"checkpoint.save={checkpoint_save_dir}",
142-
"checkpoint.save_interval=10",
143-
"train.train_iters=10",
158+
f"checkpoint.save_interval={save_interval}",
159+
f"train.train_iters={train_iters}",
144160
"train.eval_interval=5",
145161
"train.eval_iters=2",
146162
"train.global_batch_size=8",
147163
"scheduler.lr_warmup_iters=2",
148-
"scheduler.lr_decay_iters=10",
164+
f"scheduler.lr_decay_iters={train_iters}",
149165
]
150166

151167
# Always add parallelism arguments to override script defaults
@@ -154,7 +170,7 @@ def _run_pretrain_from_quantized_checkpoint(
154170
cmd.append(f"model.context_parallel_size={cp}")
155171

156172
result = subprocess.run(cmd, capture_output=True, text=True, cwd=Path(__file__).parent.parent.parent.parent)
157-
return result
173+
return result, final_iteration
158174

159175
@pytest.mark.run_only_on("GPU")
160176
@pytest.mark.parametrize("recipe_name,parallelism_overrides", QAT_WORKFLOW_CONFIGS)
@@ -212,13 +228,17 @@ def test_qat_workflow(self, recipe_name, parallelism_overrides, tmp_path):
212228

213229
print(f"=== STEP 2: Running pre-training from quantized checkpoint for {recipe_name} ===")
214230
# Step 2: Run pre-training from the quantized checkpoint
215-
pretrain_result = self._run_pretrain_from_quantized_checkpoint(
231+
train_iters = 10
232+
save_interval = 10
233+
pretrain_result, expected_iteration = self._run_pretrain_from_quantized_checkpoint(
216234
quantized_checkpoint_path=str(quantized_checkpoint_dir),
217235
checkpoint_save_dir=str(checkpoint_save_dir),
218236
hf_model_id="meta-llama/Llama-3.2-1B",
219237
tp=tensor_model_parallel_size or 1,
220238
pp=pipeline_model_parallel_size or 1,
221239
cp=context_parallel_size or 2, # Default context parallelism is 2
240+
train_iters=train_iters,
241+
save_interval=save_interval,
222242
)
223243

224244
if pretrain_result.returncode != 0:
@@ -227,12 +247,63 @@ def test_qat_workflow(self, recipe_name, parallelism_overrides, tmp_path):
227247
assert False, f"Pre-training step failed with return code {pretrain_result.returncode}"
228248

229249
print("✓ Pre-training from quantized checkpoint completed successfully")
250+
print(f" Training ran for {train_iters} iterations, saving every {save_interval} iterations")
251+
print(f" Expected final checkpoint iteration: {expected_iteration}")
230252

231-
# Verify checkpoint files were created (simple existence check, not full distributed verification)
253+
# Verify checkpoint files were created with comprehensive checks
254+
# (adapted from verify_checkpoint_files but without requiring torch.distributed)
232255
assert checkpoint_save_dir.exists(), f"Checkpoint save directory not found at {checkpoint_save_dir}"
233-
checkpoint_dirs = list(checkpoint_save_dir.iterdir())
234-
assert len(checkpoint_dirs) > 0, f"No checkpoints saved in {checkpoint_save_dir}"
235-
print(f"✓ Checkpoint files verified: {[d.name for d in checkpoint_dirs]}")
256+
257+
# Verify Megatron-Bridge tracker file
258+
latest_tracker_file = get_checkpoint_train_state_filename(str(checkpoint_save_dir), prefix=TRACKER_PREFIX)
259+
assert os.path.exists(latest_tracker_file), (
260+
f"Latest checkpoint tracker file not found at {latest_tracker_file}"
261+
)
262+
print(f"✓ Megatron-Bridge tracker file found: {latest_tracker_file}")
263+
264+
# Verify Megatron-LM compatibility tracker file
265+
megatron_lm_tracker = get_checkpoint_tracker_filename(str(checkpoint_save_dir))
266+
assert os.path.exists(megatron_lm_tracker), f"Megatron-LM tracker file not found at {megatron_lm_tracker}"
267+
print(f"✓ Megatron-LM tracker file found: {megatron_lm_tracker}")
268+
269+
# Verify the tracker file contains the correct iteration
270+
with open(megatron_lm_tracker, "r") as f:
271+
saved_iteration = f.read().strip()
272+
assert saved_iteration == str(expected_iteration), (
273+
f"Megatron-LM tracker file contains '{saved_iteration}', expected '{expected_iteration}'"
274+
)
275+
print(f"✓ Tracker file contains correct iteration: {expected_iteration}")
276+
277+
# Verify final checkpoint directory exists
278+
final_iter_dir = get_checkpoint_name(str(checkpoint_save_dir), expected_iteration, release=False)
279+
assert os.path.exists(final_iter_dir), f"Final checkpoint directory not found at {final_iter_dir}"
280+
print(f"✓ Final checkpoint directory found: {final_iter_dir}")
281+
282+
# Verify metadata file exists
283+
metadata_file = os.path.join(final_iter_dir, ".metadata")
284+
assert os.path.exists(metadata_file), f"Checkpoint metadata file not found at {metadata_file}"
285+
print(f"✓ Metadata file found: {metadata_file}")
286+
287+
# Verify .distcp files (torch.distributed.checkpoint format)
288+
distcp_files = [f for f in os.listdir(final_iter_dir) if f.endswith(".distcp")]
289+
290+
# Calculate expected world size from parallelism settings
291+
tp = tensor_model_parallel_size or 1
292+
pp = pipeline_model_parallel_size or 1
293+
cp = context_parallel_size or 2
294+
world_size = tp * pp * cp
295+
296+
# For torch_dist format, expect 2 * world_size .distcp files
297+
# (one for model state, one for optimizer state per rank)
298+
expected_distcp_files = 2 * world_size
299+
assert len(distcp_files) == expected_distcp_files, (
300+
f"Expected {expected_distcp_files} .distcp files (2 * {world_size} world_size), "
301+
f"found {len(distcp_files)}: {distcp_files}"
302+
)
303+
print(
304+
f"✓ Correct number of .distcp files: {len(distcp_files)} "
305+
f"(world_size={world_size}, tp={tp}, pp={pp}, cp={cp})"
306+
)
236307

237308
print(f"SUCCESS: Complete QAT workflow test passed for {recipe_name}")
238309

tests/functional_tests/quantization/test_quantization_workflow.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,12 @@ def test_quantization_and_generation_single_gpu(self, tmp_path):
159159
assert False, f"Generation step failed with return code {generation_result.returncode}"
160160

161161
# Verify generation succeeded
162-
assert f"Loaded quantized model from: {quantized_checkpoint_dir}" in generation_result.stdout, (
163-
f"Checkpoint loading message not found. Output: {generation_result.stdout}"
164-
)
162+
# Note: stdout may have line wrapping, so we normalize it by removing newlines within the output
163+
stdout_normalized = generation_result.stdout.replace("\n", " ")
164+
assert (
165+
"Loaded quantized model from:" in generation_result.stdout
166+
and str(quantized_checkpoint_dir) in stdout_normalized
167+
), f"Checkpoint loading message not found. Output: {generation_result.stdout}"
165168
assert "Testing quantized model with custom prompts" in generation_result.stdout, (
166169
f"Generation test message not found. Output: {generation_result.stdout}"
167170
)
@@ -180,9 +183,10 @@ def test_quantization_and_generation_single_gpu(self, tmp_path):
180183
@pytest.mark.parametrize(
181184
"quant_tp,quant_pp,gen_tp,gen_pp,test_name",
182185
[
186+
(1, 1, 2, 1, "TP1_to_TP2"), # quantize with tp=1, generate with tp=2
183187
(2, 1, 1, 1, "TP2_to_Single"), # quantize with tp=2, generate with tp=1
184188
(1, 1, 1, 2, "PP1_to_PP2"), # quantize with pp=1, generate with pp=2
185-
(1, 2, 1, 1, "PP2_to_Single"), # additional: quantize pp=2, generate single
189+
(1, 2, 1, 1, "PP2_to_Single"), # quantize pp=2, generate single
186190
],
187191
)
188192
def test_quantization_and_generation_parallelism(self, tmp_path, quant_tp, quant_pp, gen_tp, gen_pp, test_name):
@@ -242,9 +246,12 @@ def test_quantization_and_generation_parallelism(self, tmp_path, quant_tp, quant
242246
assert False, f"Generation step for {test_name} failed with return code {generation_result.returncode}"
243247

244248
# Verify generation succeeded with correct parallelism
245-
assert f"Loaded quantized model from: {quantized_checkpoint_dir}" in generation_result.stdout, (
246-
f"Checkpoint loading message not found in {test_name}. Output: {generation_result.stdout}"
247-
)
249+
# Note: stdout may have line wrapping, so we normalize it by removing newlines within the output
250+
stdout_normalized = generation_result.stdout.replace("\n", " ")
251+
assert (
252+
"Loaded quantized model from:" in generation_result.stdout
253+
and str(quantized_checkpoint_dir) in stdout_normalized
254+
), f"Checkpoint loading message not found in {test_name}. Output: {generation_result.stdout}"
248255
assert f"Tensor parallel size: {gen_tp}" in generation_result.stdout, (
249256
f"Generation TP setting not found in {test_name}. Output: {generation_result.stdout}"
250257
)

0 commit comments

Comments
 (0)