# limitations under the License.

import functools
+import os
import gc
import inspect
import logging
    PreTrainedModel,
)
from transformers.models.auto.auto_factory import _BaseAutoModelClass
+from transformers.utils.hub import TRANSFORMERS_CACHE

from nemo_automodel import __version__
from nemo_automodel._transformers.registry import ModelRegistry
from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager
from nemo_automodel.components.distributed.ddp import DDPManager
from torch.distributed.device_mesh import DeviceMesh
+from nemo_automodel.components.checkpoint.checkpointing import Checkpointer, CheckpointingConfig

HAS_LIGER_KERNEL, liger_kernel_trf = safe_import("liger_kernel.transformers")
logger = logging.getLogger(__name__)
@@ -302,19 +305,21 @@ def _retry(**override):
                        config, *model_args, **kwargs
                    )
                    logger.info(f"Using custom model implementation for {config.architectures[0]}")
-                    return model
+                else:
+                    model = None
            except Exception as e:
                logger.error(f"Failed to use custom model implementation with error: {e}")

            if quantization_config is not None:
                kwargs["quantization_config"] = quantization_config
-            model = super().from_pretrained(
-                pretrained_model_name_or_path,
-                *model_args,
-                torch_dtype=torch_dtype,
-                attn_implementation=attn_implementation,
-                **kwargs,
-            )
+            if model is None:
+                model = super().from_pretrained(
+                    pretrained_model_name_or_path,
+                    *model_args,
+                    torch_dtype=torch_dtype,
+                    attn_implementation=attn_implementation,
+                    **kwargs,
+                )
            cls.__name__ = name
        except ValueError as e:
            if "does not support" in str(e):
@@ -325,6 +330,12 @@ def _retry(**override):
                return _retry(attn_implementation=attn_implementation)
            raise e

+        # Copy state dict keys before any parallelization (used by checkpointing)
+        try:
+            state_dict_keys_before_parallel = list(model.state_dict().keys())
+        except Exception:
+            state_dict_keys_before_parallel = []
+
        # Decide kernel patching based on requested parallelism
        if distributed is not None:
            conf = distributed if isinstance(distributed, dict) else {}
@@ -333,6 +344,14 @@ def _retry(**override):
            if tp > 1 or cp > 1:
                use_liger_kernel = False

+            # Guard for CP requiring SDPA support on some models (moved from train_ft)
+            try:
+                if cp > 1 and hasattr(model, "_supports_sdpa") and model._supports_sdpa is False:
+                    raise ValueError("Model does not support SDPA required for context parallelism")
+            except AttributeError:
+                # If the attribute is missing, do not block
+                pass
+
        # Build internal wrapper (FSDP2 or DDP) using device_mesh or config
        internal_wrapper = None
        if torch.distributed.is_initialized():
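For reference, the configuration keys consumed when building the internal wrapper (here and in the matching `from_config` hunk below) are `tp_size`, `cp_size`, `pp_size`, `dp_size`, `world_size`, `backend`, `use_hf_tp_plan`, `sequence_parallel`, and `activation_checkpointing`. A hedged usage sketch; the class name and import path are assumptions, since the enclosing class is not visible in this diff:

# Illustrative call site; only the `distributed` keys are taken from the hunks above and below.
from nemo_automodel import NeMoAutoModelForCausalLM  # assumed class/import path

model = NeMoAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # any HF model id or local path
    distributed={
        "tp_size": 2,                     # tp_size > 1 or cp_size > 1 disables the Liger patch and selects FSDP2
        "cp_size": 1,
        "pp_size": 1,
        "dp_size": 0,                     # 0 is passed to the manager as dp_size=None
        "backend": "nccl",
        "use_hf_tp_plan": False,
        "sequence_parallel": False,
        "activation_checkpointing": True,
    },
)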
@@ -400,6 +419,46 @@ def _retry(**override):
        dev = _choose_device(device)
        _move_module_to_device(model, dev, torch_dtype)

+        # Auto-detect local NeMo checkpoint path and load via Checkpointer
+        try:
+            is_local_path = isinstance(pretrained_model_name_or_path, str) and os.path.isdir(pretrained_model_name_or_path)
+            ckpt_model_dir = (
+                os.path.join(pretrained_model_name_or_path, "model") if is_local_path else None
+            )
+            if ckpt_model_dir and os.path.isdir(ckpt_model_dir):
+                ckpt_conf_dict = dict(
+                    enabled=True,
+                    checkpoint_dir=pretrained_model_name_or_path,
+                    model_save_format="safetensors",
+                    model_repo_id=None,
+                    model_cache_dir=kwargs.get("cache_dir", TRANSFORMERS_CACHE),
+                    save_consolidated=True,
+                    is_peft=False,
+                    is_async=False,
+                    dequantize_base_checkpoint=False,
+                )
+                if state_dict_keys_before_parallel:
+                    ckpt_conf_dict["model_state_dict_keys"] = state_dict_keys_before_parallel
+                checkpoint_config = CheckpointingConfig(**ckpt_conf_dict)
+
+                dp_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+                tp_rank = 0
+                pp_rank = 0
+
+                model._nemo_checkpointer = Checkpointer(
+                    config=checkpoint_config,
+                    dp_rank=dp_rank,
+                    tp_rank=tp_rank,
+                    pp_rank=pp_rank,
+                    moe_mesh=getattr(internal_wrapper, "moe_mesh", None) if "internal_wrapper" in locals() else None,
+                )
+                try:
+                    model._nemo_checkpointer.load_model(model, ckpt_model_dir)
+                except Exception as e:
+                    logger.warning(f"Failed to load local NeMo checkpoint from {ckpt_model_dir}: {e}")
+        except Exception as e:
+            logger.warning(f"Checkpoint autodetection failed: {e}")
+
        model.config.update({"nemo_version": __version__})
        return model
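The auto-detection above only triggers when `pretrained_model_name_or_path` is an existing local directory that contains a `model/` subdirectory; Hub model ids and ordinary Hugging Face checkpoints take the regular loading path untouched. A small standalone sketch of that detection rule (the helper name is invented for illustration):

import os

def _find_local_nemo_model_dir(path):
    """Return `<path>/model` when `path` looks like a local NeMo checkpoint, else None."""
    if not (isinstance(path, str) and os.path.isdir(path)):
        return None  # Hub ids and missing paths fall back to normal HF loading
    model_dir = os.path.join(path, "model")
    return model_dir if os.path.isdir(model_dir) else None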
@@ -518,30 +577,41 @@ def _retry(**override):
                return _retry(attn_implementation="eager")
            raise e

-        # Initialize model wrapper from `distributed` if provided
-        if distributed is not None and model_wrapper is None:
-            if isinstance(distributed, FSDP2Manager):
-                model_wrapper = distributed
-            elif isinstance(distributed, dict):
-                init_kwargs = dict(distributed)
-                try:
-                    import torch.distributed as dist
+        # Build internal wrapper (FSDP2 or DDP) using device_mesh or config
+        internal_wrapper = None
+        if torch.distributed.is_initialized():
+            conf = distributed if isinstance(distributed, dict) else {}
+            tp = int(conf.get("tp_size", 1))
+            cp = int(conf.get("cp_size", 1))
+            pp = int(conf.get("pp_size", 1))
+            dp = int(conf.get("dp_size", 0))
+            world_size = conf.get("world_size", torch.distributed.get_world_size())
+            backend = conf.get("backend", "nccl" if torch.cuda.is_available() else "gloo")

-                    if "world_size" not in init_kwargs and dist.is_available() and dist.is_initialized():
-                        init_kwargs["world_size"] = dist.get_world_size()
-                except Exception:
-                    pass
-                if "backend" not in init_kwargs:
-                    init_kwargs["backend"] = "nccl" if torch.cuda.is_available() else "gloo"
-                model_wrapper = FSDP2Manager(**init_kwargs)
+            if tp > 1 or cp > 1 or pp > 1:
+                internal_wrapper = FSDP2Manager(
+                    dp_size=dp if dp > 0 else None,
+                    tp_size=tp,
+                    cp_size=cp,
+                    pp_size=pp,
+                    backend=backend,
+                    world_size=world_size,
+                    use_hf_tp_plan=bool(conf.get("use_hf_tp_plan", False)),
+                    sequence_parallel=bool(conf.get("sequence_parallel", False)),
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )
+                # If caller provided a prebuilt DeviceMesh, inject it
+                if device_mesh is not None:
+                    internal_wrapper.device_mesh = device_mesh
+            elif conf:
+                internal_wrapper = DDPManager(
+                    backend=backend,
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )

        # If distributed tensor/context parallelism is requested, avoid Liger patch like in train_ft
-        if model_wrapper is not None:
-            try:
-                if getattr(model_wrapper, "tp_size", 1) > 1 or getattr(model_wrapper, "cp_size", 1) > 1:
-                    use_liger_kernel = False
-            except Exception:
-                pass
+        # (handled above when building internal_wrapper in from_pretrained; for from_config, we rely only on
+        # the distributed dict and the internal_wrapper built above.)

        # Kernel patching
        try:
@@ -563,14 +633,14 @@ def _retry(**override):

        # Optional: parallelize model according to parallel_scheme decision
        # Priority: FSDP2 when tp_size>1 or cp_size>1; else DDP when dp_size>1; else single GPU
-        if model_wrapper is not None:
+        if internal_wrapper is not None:
            if torch.distributed.is_initialized():
                try:
-                    model = model_wrapper.parallelize(model)
+                    model = internal_wrapper.parallelize(model)
                except Exception as e:
-                    logger.warning("model_wrapper.parallelize failed: %s", e)
+                    logger.warning("internal parallelize failed: %s", e)
            else:
-                logger.warning("model_wrapper provided but torch.distributed is not initialized; skipping parallelize")
+                logger.warning("distributed requested but torch.distributed is not initialized; skipping parallelize")
        elif distributed is not None and torch.distributed.is_initialized():
            conf = distributed if isinstance(distributed, dict) else {}
            tp = int(conf.get("tp_size", 1))
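The priority stated in the comment at the top of this hunk can be summarized as: FSDP2 when `tp_size > 1` or `cp_size > 1`, DDP when only `dp_size > 1`, otherwise single GPU. A tiny sketch of that decision, with an invented helper name:

def _choose_parallel_scheme(tp_size=1, cp_size=1, dp_size=1):
    # Mirrors the priority comment above; the function name is illustrative only.
    if tp_size > 1 or cp_size > 1:
        return "fsdp2"
    if dp_size > 1:
        return "ddp"
    return "single_gpu"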