Commit 32c3357

step

Signed-off-by: Alexandros Koumparoulis <[email protected]>
1 parent 31f29b8 commit 32c3357

File tree: 1 file changed, +103 -33 lines
nemo_automodel/_transformers/auto_model.py

Lines changed: 103 additions & 33 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.

 import functools
+import os
 import gc
 import inspect
 import logging
@@ -30,6 +31,7 @@
     PreTrainedModel,
 )
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
+from transformers.utils.hub import TRANSFORMERS_CACHE

 from nemo_automodel import __version__
 from nemo_automodel._transformers.registry import ModelRegistry
@@ -38,6 +40,7 @@
 from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager
 from nemo_automodel.components.distributed.ddp import DDPManager
 from torch.distributed.device_mesh import DeviceMesh
+from nemo_automodel.components.checkpoint.checkpointing import Checkpointer, CheckpointingConfig

 HAS_LIGER_KERNEL, liger_kernel_trf = safe_import("liger_kernel.transformers")
 logger = logging.getLogger(__name__)
@@ -302,19 +305,21 @@ def _retry(**override):
                         config, *model_args, **kwargs
                     )
                     logger.info(f"Using custom model implementation for {config.architectures[0]}")
-                    return model
+                else:
+                    model = None
             except Exception as e:
                 logger.error(f"Failed to use custom model implementation with error: {e}")

             if quantization_config is not None:
                 kwargs["quantization_config"] = quantization_config
-            model = super().from_pretrained(
-                pretrained_model_name_or_path,
-                *model_args,
-                torch_dtype=torch_dtype,
-                attn_implementation=attn_implementation,
-                **kwargs,
-            )
+            if model is None:
+                model = super().from_pretrained(
+                    pretrained_model_name_or_path,
+                    *model_args,
+                    torch_dtype=torch_dtype,
+                    attn_implementation=attn_implementation,
+                    **kwargs,
+                )
             cls.__name__ = name
         except ValueError as e:
             if "does not support" in str(e):
@@ -325,6 +330,12 @@ def _retry(**override):
                 return _retry(attn_implementation=attn_implementation)
             raise e

+        # Copy state dict keys before any parallelization (used by checkpointing)
+        try:
+            state_dict_keys_before_parallel = list(model.state_dict().keys())
+        except Exception:
+            state_dict_keys_before_parallel = []
+
         # Decide kernel patching based on requested parallelism
         if distributed is not None:
             conf = distributed if isinstance(distributed, dict) else {}
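
The key snapshot added above is taken before any FSDP2/DDP wrapping because wrapping can change parameter key names. A tiny plain-PyTorch illustration of the effect (torch.nn.Sequential stands in for a distributed wrapper; it is not what FSDP2Manager does):

    import torch

    net = torch.nn.Linear(4, 2)
    print(list(net.state_dict().keys()))      # ['weight', 'bias']

    wrapped = torch.nn.Sequential(net)        # stand-in for a wrapper that re-nests the module
    print(list(wrapped.state_dict().keys()))  # ['0.weight', '0.bias'] -- the names changed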
@@ -333,6 +344,14 @@ def _retry(**override):
             if tp > 1 or cp > 1:
                 use_liger_kernel = False

+            # Guard for CP requiring SDPA support on some models (moved from train_ft)
+            try:
+                if cp > 1 and hasattr(model, "_supports_sdpa") and model._supports_sdpa is False:
+                    raise ValueError("Model does not support SDPA required for context parallelism")
+            except Exception:
+                # If attribute missing, do not block
+                pass
+
         # Build internal wrapper (FSDP2 or DDP) using device_mesh or config
         internal_wrapper = None
         if torch.distributed.is_initialized():
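
A hedged sketch of what the SDPA guard above checks, in isolation; `NoSdpaModel` and `check_cp` are illustrative names, and unlike the committed code this sketch lets the ValueError propagate rather than wrapping the check in a broad try/except:

    class NoSdpaModel:
        _supports_sdpa = False

    def check_cp(model, cp_size: int) -> None:
        # Block only when CP is requested and the model explicitly reports no SDPA
        # support; a missing attribute does not block, matching the hasattr check.
        if cp_size > 1 and getattr(model, "_supports_sdpa", True) is False:
            raise ValueError("Model does not support SDPA required for context parallelism")

    check_cp(NoSdpaModel(), cp_size=1)  # CP not requested: no error
    try:
        check_cp(NoSdpaModel(), cp_size=2)
    except ValueError as exc:
        print(exc)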
@@ -400,6 +419,46 @@ def _retry(**override):
         dev = _choose_device(device)
         _move_module_to_device(model, dev, torch_dtype)

+        # Auto-detect local NeMo checkpoint path and load via Checkpointer
+        try:
+            is_local_path = isinstance(pretrained_model_name_or_path, str) and os.path.isdir(pretrained_model_name_or_path)
+            ckpt_model_dir = (
+                os.path.join(pretrained_model_name_or_path, "model") if is_local_path else None
+            )
+            if ckpt_model_dir and os.path.isdir(ckpt_model_dir):
+                ckpt_conf_dict = dict(
+                    enabled=True,
+                    checkpoint_dir=pretrained_model_name_or_path,
+                    model_save_format="safetensors",
+                    model_repo_id=None,
+                    model_cache_dir=kwargs.get("cache_dir", TRANSFORMERS_CACHE),
+                    save_consolidated=True,
+                    is_peft=False,
+                    is_async=False,
+                    dequantize_base_checkpoint=False,
+                )
+                if state_dict_keys_before_parallel:
+                    ckpt_conf_dict["model_state_dict_keys"] = state_dict_keys_before_parallel
+                checkpoint_config = CheckpointingConfig(**ckpt_conf_dict)
+
+                dp_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+                tp_rank = 0
+                pp_rank = 0
+
+                model._nemo_checkpointer = Checkpointer(
+                    config=checkpoint_config,
+                    dp_rank=dp_rank,
+                    tp_rank=tp_rank,
+                    pp_rank=pp_rank,
+                    moe_mesh=getattr(internal_wrapper, "moe_mesh", None) if 'internal_wrapper' in locals() else None,
+                )
+                try:
+                    model._nemo_checkpointer.load_model(model, ckpt_model_dir)
+                except Exception as e:
+                    logger.warning(f"Failed to load local NeMo checkpoint from {ckpt_model_dir}: {e}")
+        except Exception as e:
+            logger.warning(f"Checkpoint autodetection failed: {e}")
+
         model.config.update({"nemo_version": __version__})
         return model

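With this change, pointing `from_pretrained` at a local directory that contains a `model/` subfolder routes loading through the Checkpointer. A hedged usage sketch; the directory layout and the import below are assumptions based on this diff, not verified against the rest of the repo:

    # Assumed NeMo-style checkpoint layout:
    #   /ckpts/step_1000/
    #       model/        <- presence of this subfolder triggers the Checkpointer branch
    from nemo_automodel import NeMoAutoModelForCausalLM  # assumed public entry point

    model = NeMoAutoModelForCausalLM.from_pretrained("/ckpts/step_1000")
    # Hub ids, or local dirs without a model/ subfolder, skip the branch and fall
    # back to the regular Hugging Face loading path.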
@@ -518,30 +577,41 @@ def _retry(**override):
                 return _retry(attn_implementation="eager")
             raise e

-        # Initialize model wrapper from `distributed` if provided
-        if distributed is not None and model_wrapper is None:
-            if isinstance(distributed, FSDP2Manager):
-                model_wrapper = distributed
-            elif isinstance(distributed, dict):
-                init_kwargs = dict(distributed)
-                try:
-                    import torch.distributed as dist
+        # Build internal wrapper (FSDP2 or DDP) using device_mesh or config
+        internal_wrapper = None
+        if torch.distributed.is_initialized():
+            conf = distributed if isinstance(distributed, dict) else {}
+            tp = int(conf.get("tp_size", 1))
+            cp = int(conf.get("cp_size", 1))
+            pp = int(conf.get("pp_size", 1))
+            dp = int(conf.get("dp_size", 0))
+            world_size = conf.get("world_size", torch.distributed.get_world_size())
+            backend = conf.get("backend", "nccl" if torch.cuda.is_available() else "gloo")

-                    if "world_size" not in init_kwargs and dist.is_available() and dist.is_initialized():
-                        init_kwargs["world_size"] = dist.get_world_size()
-                except Exception:
-                    pass
-                if "backend" not in init_kwargs:
-                    init_kwargs["backend"] = "nccl" if torch.cuda.is_available() else "gloo"
-                model_wrapper = FSDP2Manager(**init_kwargs)
+            if tp > 1 or cp > 1 or pp > 1:
+                internal_wrapper = FSDP2Manager(
+                    dp_size=dp if dp > 0 else None,
+                    tp_size=tp,
+                    cp_size=cp,
+                    pp_size=pp,
+                    backend=backend,
+                    world_size=world_size,
+                    use_hf_tp_plan=bool(conf.get("use_hf_tp_plan", False)),
+                    sequence_parallel=bool(conf.get("sequence_parallel", False)),
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )
+                # If caller provided a prebuilt DeviceMesh, inject it
+                if device_mesh is not None:
+                    internal_wrapper.device_mesh = device_mesh
+            elif conf:
+                internal_wrapper = DDPManager(
+                    backend=backend,
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )

         # If distributed tensor/context parallelism is requested, avoid Liger patch like in train_ft
-        if model_wrapper is not None:
-            try:
-                if getattr(model_wrapper, "tp_size", 1) > 1 or getattr(model_wrapper, "cp_size", 1) > 1:
-                    use_liger_kernel = False
-            except Exception:
-                pass
+        # (handled above when building internal_wrapper in from_pretrained; for from_config, we only rely on
+        # distributed dict and internal_wrapper built below.)

         # Kernel patching
         try:
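
The keys read from `conf` in this hunk define the shape of the `distributed` dict accepted here. A hedged example configuration; values are illustrative and not a claim about which combinations are supported end to end:

    distributed = {
        "tp_size": 2,                    # >1 here (or cp/pp > 1) selects the FSDP2Manager branch
        "cp_size": 1,
        "pp_size": 1,
        "dp_size": 0,                    # 0 means dp_size=None is passed to FSDP2Manager
        "backend": "nccl",               # default is nccl on CUDA, gloo otherwise
        "use_hf_tp_plan": False,
        "sequence_parallel": False,
        "activation_checkpointing": True,
    }
    # A non-empty dict with tp_size == cp_size == pp_size == 1 takes the DDPManager branch instead.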
@@ -563,14 +633,14 @@ def _retry(**override):

         # Optional: parallelize model according to parallel_scheme decision
         # Priority: FSDP2 when tp_size>1 or cp_size>1; else DDP when dp_size>1; else single GPU
-        if model_wrapper is not None:
+        if internal_wrapper is not None:
             if torch.distributed.is_initialized():
                 try:
-                    model = model_wrapper.parallelize(model)
+                    model = internal_wrapper.parallelize(model)
                 except Exception as e:
-                    logger.warning("model_wrapper.parallelize failed: %s", e)
+                    logger.warning("internal parallelize failed: %s", e)
             else:
-                logger.warning("model_wrapper provided but torch.distributed is not initialized; skipping parallelize")
+                logger.warning("distributed requested but torch.distributed is not initialized; skipping parallelize")
         elif distributed is not None and torch.distributed.is_initialized():
             conf = distributed if isinstance(distributed, dict) else {}
             tp = int(conf.get("tp_size", 1))
