 
 """Functional tests for QAT (Quantization Aware Training) workflow."""
 
+import os
 import subprocess
 from pathlib import Path
 
 import pytest
 
+from megatron.bridge.training.utils.checkpoint_utils import (
+    TRACKER_PREFIX,
+    get_checkpoint_name,
+    get_checkpoint_tracker_filename,
+    get_checkpoint_train_state_filename,
+)
 from tests.functional_tests.utils import clear_directories
 
 
@@ -98,6 +105,8 @@ def _run_pretrain_from_quantized_checkpoint(
         tp: int = 1,
         pp: int = 1,
         cp: int = 2,
+        train_iters: int = 10,
+        save_interval: int = 10,
     ):
         """
         Run pre-training from a quantized checkpoint using subprocess.
@@ -109,9 +118,12 @@ def _run_pretrain_from_quantized_checkpoint(
             tp: Tensor parallelism size
             pp: Pipeline parallelism size
             cp: Context parallelism size (default: 2)
+            train_iters: Number of training iterations
+            save_interval: Interval, in iterations, between checkpoint saves
 
         Returns:
-            subprocess.CompletedProcess: The result of the subprocess run
+            tuple: (subprocess.CompletedProcess, final_iteration),
+                where final_iteration is the iteration of the last checkpoint that is saved
         """
         # Calculate total number of processes needed (tp * pp * cp)
         total_procs = tp * pp * cp
@@ -121,6 +133,10 @@ def _run_pretrain_from_quantized_checkpoint(
 
         python_executable = sys.executable
 
+        # Calculate the final iteration (the last checkpoint that will be saved).
+        # Checkpoints are saved every save_interval iterations, so the last one lands on the largest multiple of save_interval that does not exceed train_iters.
+        final_iteration = (train_iters // save_interval) * save_interval
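+        # For example: train_iters=10, save_interval=10 -> 10; train_iters=25, save_interval=10 -> 20.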
+
         # Base command for pre-training from quantized checkpoint
         cmd = [
             python_executable,
@@ -139,13 +155,13 @@ def _run_pretrain_from_quantized_checkpoint(
             "model.gradient_accumulation_fusion=False",
             f"checkpoint.pretrained_checkpoint={quantized_checkpoint_path}",
             f"checkpoint.save={checkpoint_save_dir}",
-            "checkpoint.save_interval=10",
-            "train.train_iters=10",
+            f"checkpoint.save_interval={save_interval}",
+            f"train.train_iters={train_iters}",
             "train.eval_interval=5",
             "train.eval_iters=2",
             "train.global_batch_size=8",
             "scheduler.lr_warmup_iters=2",
-            "scheduler.lr_decay_iters=10",
+            f"scheduler.lr_decay_iters={train_iters}",
         ]
 
         # Always add parallelism arguments to override script defaults
@@ -154,7 +170,7 @@ def _run_pretrain_from_quantized_checkpoint(
         cmd.append(f"model.context_parallel_size={cp}")
 
         result = subprocess.run(cmd, capture_output=True, text=True, cwd=Path(__file__).parent.parent.parent.parent)
-        return result
+        return result, final_iteration
 
     @pytest.mark.run_only_on("GPU")
     @pytest.mark.parametrize("recipe_name,parallelism_overrides", QAT_WORKFLOW_CONFIGS)
@@ -212,13 +228,17 @@ def test_qat_workflow(self, recipe_name, parallelism_overrides, tmp_path):
 
         print(f"=== STEP 2: Running pre-training from quantized checkpoint for {recipe_name} ===")
         # Step 2: Run pre-training from the quantized checkpoint
-        pretrain_result = self._run_pretrain_from_quantized_checkpoint(
+        train_iters = 10
+        save_interval = 10
+        pretrain_result, expected_iteration = self._run_pretrain_from_quantized_checkpoint(
             quantized_checkpoint_path=str(quantized_checkpoint_dir),
             checkpoint_save_dir=str(checkpoint_save_dir),
             hf_model_id="meta-llama/Llama-3.2-1B",
             tp=tensor_model_parallel_size or 1,
             pp=pipeline_model_parallel_size or 1,
             cp=context_parallel_size or 2,  # Default context parallelism is 2
+            train_iters=train_iters,
+            save_interval=save_interval,
         )
 
         if pretrain_result.returncode != 0:
@@ -227,12 +247,63 @@ def test_qat_workflow(self, recipe_name, parallelism_overrides, tmp_path):
             assert False, f"Pre-training step failed with return code {pretrain_result.returncode}"
 
         print("✓ Pre-training from quantized checkpoint completed successfully")
+        print(f"  Training ran for {train_iters} iterations, saving every {save_interval} iterations")
+        print(f"  Expected final checkpoint iteration: {expected_iteration}")
 
-        # Verify checkpoint files were created (simple existence check, not full distributed verification)
+        # Verify checkpoint files were created with comprehensive checks
+        # (adapted from verify_checkpoint_files but without requiring torch.distributed)
         assert checkpoint_save_dir.exists(), f"Checkpoint save directory not found at {checkpoint_save_dir}"
-        checkpoint_dirs = list(checkpoint_save_dir.iterdir())
-        assert len(checkpoint_dirs) > 0, f"No checkpoints saved in {checkpoint_save_dir}"
-        print(f"✓ Checkpoint files verified: {[d.name for d in checkpoint_dirs]}")
+
+        # Verify Megatron-Bridge tracker file
+        latest_tracker_file = get_checkpoint_train_state_filename(str(checkpoint_save_dir), prefix=TRACKER_PREFIX)
+        assert os.path.exists(latest_tracker_file), (
+            f"Latest checkpoint tracker file not found at {latest_tracker_file}"
+        )
+        print(f"✓ Megatron-Bridge tracker file found: {latest_tracker_file}")
+
+        # Verify Megatron-LM compatibility tracker file
+        megatron_lm_tracker = get_checkpoint_tracker_filename(str(checkpoint_save_dir))
+        assert os.path.exists(megatron_lm_tracker), f"Megatron-LM tracker file not found at {megatron_lm_tracker}"
+        print(f"✓ Megatron-LM tracker file found: {megatron_lm_tracker}")
+
+        # Verify the tracker file contains the correct iteration
+        with open(megatron_lm_tracker, "r") as f:
+            saved_iteration = f.read().strip()
+        assert saved_iteration == str(expected_iteration), (
+            f"Megatron-LM tracker file contains '{saved_iteration}', expected '{expected_iteration}'"
+        )
+        print(f"✓ Tracker file contains correct iteration: {expected_iteration}")
+
+        # Verify final checkpoint directory exists
+        final_iter_dir = get_checkpoint_name(str(checkpoint_save_dir), expected_iteration, release=False)
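+        # The directory layout (e.g. an iter_0000010-style folder name) is assumed to follow the
+        # convention implemented by get_checkpoint_name; only the directory's existence is asserted below.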
+        assert os.path.exists(final_iter_dir), f"Final checkpoint directory not found at {final_iter_dir}"
+        print(f"✓ Final checkpoint directory found: {final_iter_dir}")
+
+        # Verify metadata file exists
+        metadata_file = os.path.join(final_iter_dir, ".metadata")
+        assert os.path.exists(metadata_file), f"Checkpoint metadata file not found at {metadata_file}"
+        print(f"✓ Metadata file found: {metadata_file}")
+
+        # Verify .distcp files (torch.distributed.checkpoint format)
+        distcp_files = [f for f in os.listdir(final_iter_dir) if f.endswith(".distcp")]
+
+        # Calculate expected world size from parallelism settings
+        tp = tensor_model_parallel_size or 1
+        pp = pipeline_model_parallel_size or 1
+        cp = context_parallel_size or 2
+        world_size = tp * pp * cp
+
+        # For torch_dist format, expect 2 * world_size .distcp files
+        # (one for model state, one for optimizer state per rank)
+        expected_distcp_files = 2 * world_size
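+        # e.g. tp=2, pp=1, cp=2 -> world_size=4 -> 8 expected .distcp files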
+        assert len(distcp_files) == expected_distcp_files, (
+            f"Expected {expected_distcp_files} .distcp files (2 * {world_size} world_size), "
+            f"found {len(distcp_files)}: {distcp_files}"
+        )
+        print(
+            f"✓ Correct number of .distcp files: {len(distcp_files)} "
+            f"(world_size={world_size}, tp={tp}, pp={pp}, cp={cp})"
+        )
 
         print(f"SUCCESS: Complete QAT workflow test passed for {recipe_name}")
 