
Commit c375b42

step
Signed-off-by: Alexandros Koumparoulis <[email protected]>
1 parent de26acd commit c375b42

1 file changed: +103 −33 lines

nemo_automodel/_transformers/auto_model.py

Lines changed: 103 additions & 33 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import functools
+import os
 import gc
 import inspect
 import logging
@@ -33,6 +34,7 @@
 )
 from transformers.modeling_utils import _get_resolved_checkpoint_files
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
+from transformers.utils.hub import TRANSFORMERS_CACHE
 
 from nemo_automodel import __version__
 from nemo_automodel._transformers.registry import ModelRegistry
@@ -41,6 +43,7 @@
 from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager
 from nemo_automodel.components.distributed.ddp import DDPManager
 from torch.distributed.device_mesh import DeviceMesh
+from nemo_automodel.components.checkpoint.checkpointing import Checkpointer, CheckpointingConfig
 
 HAS_LIGER_KERNEL, liger_kernel_trf = safe_import("liger_kernel.transformers")
 logger = logging.getLogger(__name__)
@@ -338,19 +341,21 @@ def _retry(**override):
                     if dist.is_initialized():
                         dist.barrier()
                     logger.info(f"Using custom model implementation for {config.architectures[0]}")
-                    return model
+                else:
+                    model = None
             except Exception as e:
                 logger.error(f"Failed to use custom model implementation with error: {e}")
 
             if quantization_config is not None:
                 kwargs["quantization_config"] = quantization_config
-            model = super().from_pretrained(
-                pretrained_model_name_or_path,
-                *model_args,
-                torch_dtype=torch_dtype,
-                attn_implementation=attn_implementation,
-                **kwargs,
-            )
+            if model is None:
+                model = super().from_pretrained(
+                    pretrained_model_name_or_path,
+                    *model_args,
+                    torch_dtype=torch_dtype,
+                    attn_implementation=attn_implementation,
+                    **kwargs,
+                )
             cls.__name__ = name
         except ValueError as e:
             if "does not support" in str(e):
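The control-flow change above is the core of this hunk: a matching custom implementation no longer returns early, and `model` is left as `None` when no custom architecture applies, so the stock `super().from_pretrained(...)` call now runs only as a fallback. A minimal, self-contained sketch of that pattern (the helper names below are illustrative, not from this file):

    import logging

    logger = logging.getLogger(__name__)

    def resolve_model(build_custom, build_hf, has_custom_impl: bool):
        """Illustrative fallback: try the custom path, else the plain HF path."""
        model = None
        try:
            if has_custom_impl:
                model = build_custom()  # custom NeMo implementation
            # else: model stays None and the fallback below is used
        except Exception as e:
            logger.error(f"Failed to use custom model implementation with error: {e}")
        if model is None:
            model = build_hf()  # fallback: transformers' from_pretrained
        return model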
@@ -361,6 +366,12 @@ def _retry(**override):
                 return _retry(attn_implementation=attn_implementation)
             raise e
 
+        # Copy state dict keys before any parallelization (used by checkpointing)
+        try:
+            state_dict_keys_before_parallel = list(model.state_dict().keys())
+        except Exception:
+            state_dict_keys_before_parallel = []
+
         # Decide kernel patching based on requested parallelism
         if distributed is not None:
             conf = distributed if isinstance(distributed, dict) else {}
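Capturing the key list before any wrapping matters because parallelization can change the names under which parameters appear in `state_dict()`; the checkpointing path added later in this commit consumes the original names via `model_state_dict_keys`. A small runnable sketch of the effect (the wrapper class is a stand-in, not FSDP2):

    import torch.nn as nn

    class PrefixingWrapper(nn.Module):
        """Stand-in wrapper: nesting a module prefixes every state-dict key."""
        def __init__(self, inner: nn.Module):
            super().__init__()
            self.inner = inner

    model = nn.Linear(4, 4)
    keys_before = list(model.state_dict().keys())   # ['weight', 'bias']
    wrapped = PrefixingWrapper(model)
    keys_after = list(wrapped.state_dict().keys())  # ['inner.weight', 'inner.bias']
    print(keys_before, keys_after)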
@@ -369,6 +380,14 @@ def _retry(**override):
             if tp > 1 or cp > 1:
                 use_liger_kernel = False
 
+        # Guard for CP requiring SDPA support on some models (moved from train_ft)
+        try:
+            if cp > 1 and hasattr(model, "_supports_sdpa") and model._supports_sdpa is False:
+                raise ValueError("Model does not support SDPA required for context parallelism")
+        except Exception:
+            # If attribute missing, do not block
+            pass
+
         # Build internal wrapper (FSDP2 or DDP) using device_mesh or config
         internal_wrapper = None
         if torch.distributed.is_initialized():
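Note the design choice in the guard above: the `raise ValueError` sits inside a broad `except Exception: pass`, which also swallows the raised `ValueError` itself, so as written the check cannot actually abort loading; it is advisory in this commit. A stricter stand-alone variant might look like the sketch below; this is an assumption about intent, not code from the diff:

    def check_cp_sdpa(model, cp_size: int) -> None:
        """Hard-fail only when the model explicitly reports no SDPA support."""
        if cp_size <= 1:
            return
        supports_sdpa = getattr(model, "_supports_sdpa", None)  # None: unknown, do not block
        if supports_sdpa is False:
            raise ValueError("Model does not support SDPA required for context parallelism")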
@@ -436,6 +455,46 @@ def _retry(**override):
         dev = _choose_device(device)
         _move_module_to_device(model, dev, torch_dtype)
 
+        # Auto-detect local NeMo checkpoint path and load via Checkpointer
+        try:
+            is_local_path = isinstance(pretrained_model_name_or_path, str) and os.path.isdir(pretrained_model_name_or_path)
+            ckpt_model_dir = (
+                os.path.join(pretrained_model_name_or_path, "model") if is_local_path else None
+            )
+            if ckpt_model_dir and os.path.isdir(ckpt_model_dir):
+                ckpt_conf_dict = dict(
+                    enabled=True,
+                    checkpoint_dir=pretrained_model_name_or_path,
+                    model_save_format="safetensors",
+                    model_repo_id=None,
+                    model_cache_dir=kwargs.get("cache_dir", TRANSFORMERS_CACHE),
+                    save_consolidated=True,
+                    is_peft=False,
+                    is_async=False,
+                    dequantize_base_checkpoint=False,
+                )
+                if state_dict_keys_before_parallel:
+                    ckpt_conf_dict["model_state_dict_keys"] = state_dict_keys_before_parallel
+                checkpoint_config = CheckpointingConfig(**ckpt_conf_dict)
+
+                dp_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+                tp_rank = 0
+                pp_rank = 0
+
+                model._nemo_checkpointer = Checkpointer(
+                    config=checkpoint_config,
+                    dp_rank=dp_rank,
+                    tp_rank=tp_rank,
+                    pp_rank=pp_rank,
+                    moe_mesh=getattr(internal_wrapper, "moe_mesh", None) if 'internal_wrapper' in locals() else None,
+                )
+                try:
+                    model._nemo_checkpointer.load_model(model, ckpt_model_dir)
+                except Exception as e:
+                    logger.warning(f"Failed to load local NeMo checkpoint from {ckpt_model_dir}: {e}")
+        except Exception as e:
+            logger.warning(f"Checkpoint autodetection failed: {e}")
+
         model.config.update({"nemo_version": __version__})
         return model
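The autodetection above treats a local directory containing a `model/` subdirectory as a NeMo checkpoint: the directory itself becomes `checkpoint_dir` and the `model/` subfolder is handed to `Checkpointer.load_model`. A hedged sketch of just the path probing, factored into a helper for clarity (the helper name is mine, not the diff's):

    import os
    from typing import Optional

    def detect_local_nemo_checkpoint(pretrained_model_name_or_path) -> Optional[str]:
        """Return <path>/model if the argument looks like a local NeMo checkpoint dir, else None."""
        if not isinstance(pretrained_model_name_or_path, str):
            return None
        if not os.path.isdir(pretrained_model_name_or_path):
            return None  # Hub repo id or non-directory path: not a local checkpoint
        model_dir = os.path.join(pretrained_model_name_or_path, "model")
        return model_dir if os.path.isdir(model_dir) else None

    # If this returns a directory, the hunk above builds a CheckpointingConfig with
    # checkpoint_dir=<path> and calls model._nemo_checkpointer.load_model(model, model_dir).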

@@ -554,30 +613,41 @@ def _retry(**override):
                 return _retry(attn_implementation="eager")
             raise e
 
-        # Initialize model wrapper from `distributed` if provided
-        if distributed is not None and model_wrapper is None:
-            if isinstance(distributed, FSDP2Manager):
-                model_wrapper = distributed
-            elif isinstance(distributed, dict):
-                init_kwargs = dict(distributed)
-                try:
-                    import torch.distributed as dist
+        # Build internal wrapper (FSDP2 or DDP) using device_mesh or config
+        internal_wrapper = None
+        if torch.distributed.is_initialized():
+            conf = distributed if isinstance(distributed, dict) else {}
+            tp = int(conf.get("tp_size", 1))
+            cp = int(conf.get("cp_size", 1))
+            pp = int(conf.get("pp_size", 1))
+            dp = int(conf.get("dp_size", 0))
+            world_size = conf.get("world_size", torch.distributed.get_world_size())
+            backend = conf.get("backend", "nccl" if torch.cuda.is_available() else "gloo")
 
-                    if "world_size" not in init_kwargs and dist.is_available() and dist.is_initialized():
-                        init_kwargs["world_size"] = dist.get_world_size()
-                except Exception:
-                    pass
-                if "backend" not in init_kwargs:
-                    init_kwargs["backend"] = "nccl" if torch.cuda.is_available() else "gloo"
-                model_wrapper = FSDP2Manager(**init_kwargs)
+            if tp > 1 or cp > 1 or pp > 1:
+                internal_wrapper = FSDP2Manager(
+                    dp_size=dp if dp > 0 else None,
+                    tp_size=tp,
+                    cp_size=cp,
+                    pp_size=pp,
+                    backend=backend,
+                    world_size=world_size,
+                    use_hf_tp_plan=bool(conf.get("use_hf_tp_plan", False)),
+                    sequence_parallel=bool(conf.get("sequence_parallel", False)),
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )
+                # If caller provided a prebuilt DeviceMesh, inject it
+                if device_mesh is not None:
+                    internal_wrapper.device_mesh = device_mesh
+            elif conf:
+                internal_wrapper = DDPManager(
+                    backend=backend,
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )
 
         # If distributed tensor/context parallelism is requested, avoid Liger patch like in train_ft
-        if model_wrapper is not None:
-            try:
-                if getattr(model_wrapper, "tp_size", 1) > 1 or getattr(model_wrapper, "cp_size", 1) > 1:
-                    use_liger_kernel = False
-            except Exception:
-                pass
+        # (handled above when building internal_wrapper in from_pretrained; for from_config, we only rely on
+        # distributed dict and internal_wrapper built below.)
 
         # Kernel patching
         try:
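The selection rule introduced here for `from_config`: any of `tp_size`, `cp_size`, or `pp_size` greater than 1 builds an `FSDP2Manager`; otherwise a non-empty `distributed` dict falls back to `DDPManager`; with no config, no wrapper is built (the real code additionally requires `torch.distributed` to be initialized). A tiny runnable sketch of that decision, returning labels instead of constructing the managers so it needs no process group:

    def choose_wrapper(conf: dict) -> str:
        """Mirror of the wrapper-selection rule in this hunk (labels only)."""
        tp = int(conf.get("tp_size", 1))
        cp = int(conf.get("cp_size", 1))
        pp = int(conf.get("pp_size", 1))
        if tp > 1 or cp > 1 or pp > 1:
            return "FSDP2Manager"  # tensor/context/pipeline parallel path
        if conf:
            return "DDPManager"    # plain data parallel when only dp-style options are given
        return "none"              # no distributed config: leave the model unwrapped

    assert choose_wrapper({"tp_size": 2}) == "FSDP2Manager"
    assert choose_wrapper({"dp_size": 4}) == "DDPManager"
    assert choose_wrapper({}) == "none"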
@@ -599,14 +669,14 @@ def _retry(**override):
 
         # Optional: parallelize model according to parallel_scheme decision
         # Priority: FSDP2 when tp_size>1 or cp_size>1; else DDP when dp_size>1; else single GPU
-        if model_wrapper is not None:
+        if internal_wrapper is not None:
             if torch.distributed.is_initialized():
                 try:
-                    model = model_wrapper.parallelize(model)
+                    model = internal_wrapper.parallelize(model)
                 except Exception as e:
-                    logger.warning("model_wrapper.parallelize failed: %s", e)
+                    logger.warning("internal parallelize failed: %s", e)
             else:
-                logger.warning("model_wrapper provided but torch.distributed is not initialized; skipping parallelize")
+                logger.warning("distributed requested but torch.distributed is not initialized; skipping parallelize")
         elif distributed is not None and torch.distributed.is_initialized():
             conf = distributed if isinstance(distributed, dict) else {}
             tp = int(conf.get("tp_size", 1))
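The parallelize step keeps the same guard shape under the new name: only call `parallelize()` when a process group exists, and degrade to a warning rather than failing the load when it does not or when parallelization itself errors. A stand-alone sketch of that pattern (the `wrapper` argument stands in for the internal FSDP2/DDP manager):

    import logging
    import torch

    logger = logging.getLogger(__name__)

    def maybe_parallelize(model, wrapper):
        """Apply wrapper.parallelize(model) only when it is safe to do so."""
        if wrapper is None:
            return model
        if not torch.distributed.is_initialized():
            logger.warning("distributed requested but torch.distributed is not initialized; skipping parallelize")
            return model
        try:
            return wrapper.parallelize(model)
        except Exception as e:
            logger.warning("internal parallelize failed: %s", e)
            return model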
