@@ -13,6 +13,7 @@
 # limitations under the License.

 import functools
+import os
 import gc
 import inspect
 import logging
@@ -33,6 +34,7 @@
 )
 from transformers.modeling_utils import _get_resolved_checkpoint_files
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
+from transformers.utils.hub import TRANSFORMERS_CACHE

 from nemo_automodel import __version__
 from nemo_automodel._transformers.registry import ModelRegistry
@@ -41,6 +43,7 @@
 from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager
 from nemo_automodel.components.distributed.ddp import DDPManager
 from torch.distributed.device_mesh import DeviceMesh
+from nemo_automodel.components.checkpoint.checkpointing import Checkpointer, CheckpointingConfig

 HAS_LIGER_KERNEL, liger_kernel_trf = safe_import("liger_kernel.transformers")
 logger = logging.getLogger(__name__)
@@ -338,19 +341,21 @@ def _retry(**override):
                     if dist.is_initialized():
                         dist.barrier()
                     logger.info(f"Using custom model implementation for {config.architectures[0]}")
-                    return model
+                else:
+                    model = None
             except Exception as e:
                 logger.error(f"Failed to use custom model implementation with error: {e}")

             if quantization_config is not None:
                 kwargs["quantization_config"] = quantization_config
-            model = super().from_pretrained(
-                pretrained_model_name_or_path,
-                *model_args,
-                torch_dtype=torch_dtype,
-                attn_implementation=attn_implementation,
-                **kwargs,
-            )
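+            # Fall back to the stock transformers loader only if the custom implementation path did not produce a model.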
+            if model is None:
+                model = super().from_pretrained(
+                    pretrained_model_name_or_path,
+                    *model_args,
+                    torch_dtype=torch_dtype,
+                    attn_implementation=attn_implementation,
+                    **kwargs,
+                )
             cls.__name__ = name
         except ValueError as e:
             if "does not support" in str(e):
@@ -361,6 +366,12 @@ def _retry(**override):
                 return _retry(attn_implementation=attn_implementation)
             raise e

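+        # Parallelization below (FSDP2/DDP) can change state_dict key names, so record the original layout first.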
+        # Copy state dict keys before any parallelization (used by checkpointing)
+        try:
+            state_dict_keys_before_parallel = list(model.state_dict().keys())
+        except Exception:
+            state_dict_keys_before_parallel = []
+
         # Decide kernel patching based on requested parallelism
         if distributed is not None:
             conf = distributed if isinstance(distributed, dict) else {}
@@ -369,6 +380,14 @@ def _retry(**override):
             if tp > 1 or cp > 1:
                 use_liger_kernel = False

+            # Guard for CP requiring SDPA support on some models (moved from train_ft).
+            # hasattr() already covers models without the flag, so no try/except is needed here;
+            # a broad `except Exception` would swallow the ValueError this guard is meant to raise.
+            if cp > 1 and hasattr(model, "_supports_sdpa") and model._supports_sdpa is False:
+                raise ValueError("Model does not support SDPA required for context parallelism")
+
         # Build internal wrapper (FSDP2 or DDP) using device_mesh or config
         internal_wrapper = None
         if torch.distributed.is_initialized():
@@ -436,6 +455,46 @@ def _retry(**override):
         dev = _choose_device(device)
         _move_module_to_device(model, dev, torch_dtype)

+        # Auto-detect local NeMo checkpoint path and load via Checkpointer
+        try:
+            is_local_path = isinstance(pretrained_model_name_or_path, str) and os.path.isdir(pretrained_model_name_or_path)
+            ckpt_model_dir = (
+                os.path.join(pretrained_model_name_or_path, "model") if is_local_path else None
+            )
+            if ckpt_model_dir and os.path.isdir(ckpt_model_dir):
+                ckpt_conf_dict = dict(
+                    enabled=True,
+                    checkpoint_dir=pretrained_model_name_or_path,
+                    model_save_format="safetensors",
+                    model_repo_id=None,
+                    model_cache_dir=kwargs.get("cache_dir", TRANSFORMERS_CACHE),
+                    save_consolidated=True,
+                    is_peft=False,
+                    is_async=False,
+                    dequantize_base_checkpoint=False,
+                )
+                if state_dict_keys_before_parallel:
+                    ckpt_conf_dict["model_state_dict_keys"] = state_dict_keys_before_parallel
+                checkpoint_config = CheckpointingConfig(**ckpt_conf_dict)
+
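+                # Only the data-parallel rank is known at this point; from_pretrained does not build
+                # TP/PP meshes itself, so those ranks default to 0 here.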
+                dp_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+                tp_rank = 0
+                pp_rank = 0
+
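+                # Stash the checkpointer on the model instance so it can be reused for later save/load calls.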
+                model._nemo_checkpointer = Checkpointer(
+                    config=checkpoint_config,
+                    dp_rank=dp_rank,
+                    tp_rank=tp_rank,
+                    pp_rank=pp_rank,
+                    moe_mesh=getattr(internal_wrapper, "moe_mesh", None) if 'internal_wrapper' in locals() else None,
+                )
+                try:
+                    model._nemo_checkpointer.load_model(model, ckpt_model_dir)
+                except Exception as e:
+                    logger.warning(f"Failed to load local NeMo checkpoint from {ckpt_model_dir}: {e}")
+        except Exception as e:
+            logger.warning(f"Checkpoint autodetection failed: {e}")
+
         model.config.update({"nemo_version": __version__})
         return model

@@ -554,30 +613,41 @@ def _retry(**override):
                 return _retry(attn_implementation="eager")
             raise e

-        # Initialize model wrapper from `distributed` if provided
-        if distributed is not None and model_wrapper is None:
-            if isinstance(distributed, FSDP2Manager):
-                model_wrapper = distributed
-            elif isinstance(distributed, dict):
-                init_kwargs = dict(distributed)
-                try:
-                    import torch.distributed as dist
+        # Build internal wrapper (FSDP2 or DDP) using device_mesh or config
+        internal_wrapper = None
+        if torch.distributed.is_initialized():
+            conf = distributed if isinstance(distributed, dict) else {}
+            tp = int(conf.get("tp_size", 1))
+            cp = int(conf.get("cp_size", 1))
+            pp = int(conf.get("pp_size", 1))
+            dp = int(conf.get("dp_size", 0))
+            world_size = conf.get("world_size", torch.distributed.get_world_size())
+            backend = conf.get("backend", "nccl" if torch.cuda.is_available() else "gloo")

-                    if "world_size" not in init_kwargs and dist.is_available() and dist.is_initialized():
-                        init_kwargs["world_size"] = dist.get_world_size()
-                except Exception:
-                    pass
-                if "backend" not in init_kwargs:
-                    init_kwargs["backend"] = "nccl" if torch.cuda.is_available() else "gloo"
-                model_wrapper = FSDP2Manager(**init_kwargs)
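+            # dp_size <= 0 is forwarded as None below, leaving the data-parallel degree up to FSDP2Manager.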
+            if tp > 1 or cp > 1 or pp > 1:
+                internal_wrapper = FSDP2Manager(
+                    dp_size=dp if dp > 0 else None,
+                    tp_size=tp,
+                    cp_size=cp,
+                    pp_size=pp,
+                    backend=backend,
+                    world_size=world_size,
+                    use_hf_tp_plan=bool(conf.get("use_hf_tp_plan", False)),
+                    sequence_parallel=bool(conf.get("sequence_parallel", False)),
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )
+                # If caller provided a prebuilt DeviceMesh, inject it
+                if device_mesh is not None:
+                    internal_wrapper.device_mesh = device_mesh
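+            # No TP/CP/PP requested: a non-empty distributed config still gets plain data-parallel (DDP) wrapping.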
+            elif conf:
+                internal_wrapper = DDPManager(
+                    backend=backend,
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )

         # If distributed tensor/context parallelism is requested, avoid Liger patch like in train_ft
-        if model_wrapper is not None:
-            try:
-                if getattr(model_wrapper, "tp_size", 1) > 1 or getattr(model_wrapper, "cp_size", 1) > 1:
-                    use_liger_kernel = False
-            except Exception:
-                pass
+        # (Handled in from_pretrained when the internal wrapper is built; for from_config we rely only on
+        # the distributed dict and the internal_wrapper constructed above.)

         # Kernel patching
         try:
@@ -599,14 +669,14 @@ def _retry(**override):

         # Optional: parallelize model according to parallel_scheme decision
         # Priority: FSDP2 when tp_size>1 or cp_size>1; else DDP when dp_size>1; else single GPU
-        if model_wrapper is not None:
+        if internal_wrapper is not None:
             if torch.distributed.is_initialized():
                 try:
-                    model = model_wrapper.parallelize(model)
+                    model = internal_wrapper.parallelize(model)
                 except Exception as e:
-                    logger.warning("model_wrapper.parallelize failed: %s", e)
+                    logger.warning("internal parallelize failed: %s", e)
             else:
-                logger.warning("model_wrapper provided but torch.distributed is not initialized; skipping parallelize")
+                logger.warning("distributed requested but torch.distributed is not initialized; skipping parallelize")
         elif distributed is not None and torch.distributed.is_initialized():
             conf = distributed if isinstance(distributed, dict) else {}
             tp = int(conf.get("tp_size", 1))
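Reviewer note: a minimal usage sketch of the paths this diff adds, assuming the local checkpoint directory contains a "model/" subfolder and that NeMoAutoModelForCausalLM is the public entry point (names and paths below are illustrative assumptions, not taken from this diff):

```python
# Hypothetical sketch: a local directory with a "model/" subfolder should trigger the
# Checkpointer-based load added above, and a dict passed as `distributed` drives the
# FSDP2/DDP wrapper construction (torch.distributed must already be initialized,
# otherwise the wrapper step is skipped).
import torch
from nemo_automodel import NeMoAutoModelForCausalLM  # entry-point name assumed

model = NeMoAutoModelForCausalLM.from_pretrained(
    "/checkpoints/run_a/step_1000",            # illustrative local dir containing "model/"
    torch_dtype=torch.bfloat16,
    distributed={"tp_size": 2, "cp_size": 1},  # any TP/CP/PP degree > 1 selects the FSDP2 path
)
```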