
Commit 31f29b8

wip
Signed-off-by: Alexandros Koumparoulis <[email protected]>
1 parent 3e6952d commit 31f29b8

2 files changed: +201 −3 lines changed

nemo_automodel/_transformers/auto_model.py

Lines changed: 173 additions & 1 deletion
@@ -17,7 +17,7 @@
 import inspect
 import logging
 import types
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 from torch.nn.attention import SDPBackend, sdpa_kernel
@@ -35,6 +35,9 @@
 from nemo_automodel._transformers.registry import ModelRegistry
 from nemo_automodel.shared.import_utils import safe_import
 from nemo_automodel.shared.utils import dtype_from_str
+from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager
+from nemo_automodel.components.distributed.ddp import DDPManager
+from torch.distributed.device_mesh import DeviceMesh
 
 HAS_LIGER_KERNEL, liger_kernel_trf = safe_import("liger_kernel.transformers")
 logger = logging.getLogger(__name__)
@@ -148,6 +151,29 @@ def _get_next_fallback_attn(attn_implementation: str) -> str:
     return priorities[0]
 
 
+def _choose_device(device: Optional[torch.device]) -> torch.device:
+    if device is not None:
+        return device
+    if torch.cuda.is_available():
+        import os
+
+        local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        return torch.device("cuda", local_rank)
+    return torch.device("cpu")
+
+
+def _move_module_to_device(module: torch.nn.Module, device: torch.device, torch_dtype: Any) -> None:
+    # torch_dtype can be "auto", torch.dtype, or string
+    if torch_dtype == "auto":
+        dtype = None
+    else:
+        dtype = dtype_from_str(torch_dtype) if isinstance(torch_dtype, str) else torch_dtype
+    if dtype is not None:
+        module.to(device=device, dtype=dtype)
+    else:
+        module.to(device=device)
+
+
 class _BaseNeMoAutoModelClass(_BaseAutoModelClass):
     """
     Drop-in replacement for ``_BaseAutoModelClass`` that includes custom-kernels.
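
For reference, a minimal sketch of how the two new helpers are meant to compose, assuming they are in scope; the Linear module is an illustrative stand-in, not part of the commit:

import torch

module = torch.nn.Linear(8, 8)                          # illustrative stand-in model
device = _choose_device(None)                           # cuda:<LOCAL_RANK> when CUDA is available, else cpu
_move_module_to_device(module, device, "bfloat16")      # string dtype: resolved via dtype_from_str, then moved
_move_module_to_device(module, device, torch.float16)   # torch.dtype: used as-is
_move_module_to_device(module, device, "auto")          # "auto": device move only, no dtype cast
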
@@ -182,6 +208,10 @@ def from_pretrained(
         attn_implementation: str = "flash_attention_2",
         quantization_config=None,
         force_hf: bool = False,
+        device_mesh: Optional[DeviceMesh] = None,
+        distributed: Optional[Dict[str, Any]] = None,
+        device: Optional[torch.device] = None,
+        move_to_device: bool = False,
         **kwargs,
     ) -> PreTrainedModel:
         """
@@ -248,6 +278,10 @@ def _retry(**override):
                 use_liger_kernel=override.get("use_liger_kernel", use_liger_kernel),
                 use_sdpa_patching=override.get("use_sdpa_patching", use_sdpa_patching),
                 sdpa_method=sdpa_method,
+                device_mesh=device_mesh,
+                distributed=distributed,
+                device=device,
+                move_to_device=move_to_device,
                 **kwargs,
             )
 

@@ -291,6 +325,47 @@ def _retry(**override):
                 return _retry(attn_implementation=attn_implementation)
             raise e
 
+        # Decide kernel patching based on requested parallelism
+        if distributed is not None:
+            conf = distributed if isinstance(distributed, dict) else {}
+            tp = int(conf.get("tp_size", 1))
+            cp = int(conf.get("cp_size", 1))
+            if tp > 1 or cp > 1:
+                use_liger_kernel = False
+
+        # Build internal wrapper (FSDP2 or DDP) using device_mesh or config
+        internal_wrapper = None
+        if torch.distributed.is_initialized():
+            conf = distributed if isinstance(distributed, dict) else {}
+            tp = int(conf.get("tp_size", 1))
+            cp = int(conf.get("cp_size", 1))
+            pp = int(conf.get("pp_size", 1))
+            dp = int(conf.get("dp_size", 0))
+            world_size = conf.get("world_size", torch.distributed.get_world_size())
+            backend = conf.get("backend", "nccl" if torch.cuda.is_available() else "gloo")
+
+            if tp > 1 or cp > 1 or pp > 1:
+                internal_wrapper = FSDP2Manager(
+                    dp_size=dp if dp > 0 else None,
+                    tp_size=tp,
+                    cp_size=cp,
+                    pp_size=pp,
+                    backend=backend,
+                    world_size=world_size,
+                    use_hf_tp_plan=bool(conf.get("use_hf_tp_plan", False)),
+                    sequence_parallel=bool(conf.get("sequence_parallel", False)),
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )
+                # If caller provided a prebuilt DeviceMesh, inject it
+                if device_mesh is not None:
+                    internal_wrapper.device_mesh = device_mesh
+            elif conf:
+                internal_wrapper = DDPManager(
+                    backend=backend,
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )
+            # no-op: model_wrapper removed; liger disabled via conf above
+
         # Kernel patching
         try:
             if use_liger_kernel:
@@ -309,6 +384,22 @@ def _retry(**override):
             logging.warning("Retrying without SDPA patching.")
             return _retry(use_sdpa_patching=False)
 
+        # Optional: parallelize model according to parallel_scheme decision
+        # Priority: FSDP2 when tp_size>1 or cp_size>1; else DDP when dp_size>1; else single GPU
+        if internal_wrapper is not None:
+            if torch.distributed.is_initialized():
+                try:
+                    model = internal_wrapper.parallelize(model)
+                except Exception as e:
+                    logger.warning("internal parallelize failed: %s", e)
+            else:
+                logger.warning("distributed requested but torch.distributed is not initialized; skipping parallelize")
+
+        # Finally move the model to the chosen device (post-parallelization), if requested
+        if move_to_device:
+            dev = _choose_device(device)
+            _move_module_to_device(model, dev, torch_dtype)
+
         model.config.update({"nemo_version": __version__})
         return model
 

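Taken together, the from_pretrained changes let a caller request parallelization and device placement in one call. A minimal caller-side sketch, assuming torch.distributed was already initialized by the launcher; the checkpoint name and the NeMoAutoModelForCausalLM entry point are illustrative, not taken from this commit:

from nemo_automodel import NeMoAutoModelForCausalLM  # assumed public wrapper over _BaseNeMoAutoModelClass

model = NeMoAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",                               # illustrative checkpoint
    torch_dtype="bfloat16",
    distributed={"tp_size": 2, "cp_size": 1, "dp_size": 4},  # tp>1 selects the FSDP2Manager path and disables Liger
    move_to_device=True,                                     # post-parallelization move via _choose_device
)
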
@@ -324,6 +415,10 @@ def from_config(
         attn_implementation: str = "flash_attention_2",
         quantization_config=None,
         force_hf: bool = False,
+        device_mesh: Optional[DeviceMesh] = None,
+        distributed: Optional[Dict[str, Any]] = None,
+        device: Optional[torch.device] = None,
+        move_to_device: bool = False,
         **kwargs,
     ) -> PreTrainedModel:
         """
@@ -382,6 +477,10 @@ def _retry(**override):
                 use_liger_kernel=override.get("use_liger_kernel", use_liger_kernel),
                 use_sdpa_patching=override.get("use_sdpa_patching", use_sdpa_patching),
                 sdpa_method=sdpa_method,
+                device_mesh=device_mesh,
+                distributed=distributed,
+                device=device,
+                move_to_device=move_to_device,
                 **kwargs,
             )
 

@@ -419,6 +518,31 @@ def _retry(**override):
                 return _retry(attn_implementation="eager")
             raise e
 
+        # Initialize model wrapper from `distributed` if provided
+        if distributed is not None and model_wrapper is None:
+            if isinstance(distributed, FSDP2Manager):
+                model_wrapper = distributed
+            elif isinstance(distributed, dict):
+                init_kwargs = dict(distributed)
+                try:
+                    import torch.distributed as dist
+
+                    if "world_size" not in init_kwargs and dist.is_available() and dist.is_initialized():
+                        init_kwargs["world_size"] = dist.get_world_size()
+                except Exception:
+                    pass
+                if "backend" not in init_kwargs:
+                    init_kwargs["backend"] = "nccl" if torch.cuda.is_available() else "gloo"
+                model_wrapper = FSDP2Manager(**init_kwargs)
+
+        # If distributed tensor/context parallelism is requested, avoid Liger patch like in train_ft
+        if model_wrapper is not None:
+            try:
+                if getattr(model_wrapper, "tp_size", 1) > 1 or getattr(model_wrapper, "cp_size", 1) > 1:
+                    use_liger_kernel = False
+            except Exception:
+                pass
+
         # Kernel patching
         try:
             if use_liger_kernel:
@@ -437,6 +561,54 @@ def _retry(**override):
             logging.warning("Retrying without SDPA patching.")
             return _retry(use_sdpa_patching=False)
 
+        # Optional: parallelize model according to parallel_scheme decision
+        # Priority: FSDP2 when tp_size>1 or cp_size>1; else DDP when dp_size>1; else single GPU
+        if model_wrapper is not None:
+            if torch.distributed.is_initialized():
+                try:
+                    model = model_wrapper.parallelize(model)
+                except Exception as e:
+                    logger.warning("model_wrapper.parallelize failed: %s", e)
+            else:
+                logger.warning("model_wrapper provided but torch.distributed is not initialized; skipping parallelize")
+        elif distributed is not None and torch.distributed.is_initialized():
+            conf = distributed if isinstance(distributed, dict) else {}
+            tp = int(conf.get("tp_size", 1))
+            cp = int(conf.get("cp_size", 1))
+            pp = int(conf.get("pp_size", 1))
+            dp = int(conf.get("dp_size", 0))
+            world_size = conf.get("world_size", torch.distributed.get_world_size())
+            backend = conf.get("backend", "nccl" if torch.cuda.is_available() else "gloo")
+
+            if tp > 1 or cp > 1 or pp > 1:
+                manager = FSDP2Manager(
+                    dp_size=dp if dp > 0 else None,
+                    tp_size=tp,
+                    cp_size=cp,
+                    pp_size=pp,
+                    backend=backend,
+                    world_size=world_size,
+                    use_hf_tp_plan=bool(conf.get("use_hf_tp_plan", False)),
+                    sequence_parallel=bool(conf.get("sequence_parallel", False)),
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )
+                try:
+                    model = manager.parallelize(model)
+                except Exception as e:
+                    logger.warning("FSDP2Manager.parallelize failed: %s", e)
+            else:
+                ddp_backend = backend
+                try:
+                    ddp_manager = DDPManager(backend=ddp_backend, activation_checkpointing=bool(conf.get("activation_checkpointing", False)))
+                    model = ddp_manager.parallelize(model)
+                except Exception as e:
+                    logger.warning("DDPManager.parallelize failed: %s", e)
+
+        # Finally move the model to the chosen device (post-parallelization), if requested
+        if move_to_device:
+            dev = _choose_device(device)
+            _move_module_to_device(model, dev, torch_dtype)
+
         model.config.update({"nemo_version": __version__})
         return model
 

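For from_config, the diff accepts either a plain dict or an existing FSDP2Manager as the distributed argument. A sketch of both call shapes; the config object and the NeMoAutoModelForCausalLM class are assumed rather than taken from this commit:

# Dict form: missing world_size/backend are filled in, then FSDP2Manager(**init_kwargs) is built.
model = NeMoAutoModelForCausalLM.from_config(
    config,                                         # a transformers config object, assumed in scope
    distributed={"tp_size": 2, "backend": "nccl"},
    move_to_device=True,
)

# Manager form: a prebuilt FSDP2Manager is used directly as the model wrapper.
manager = FSDP2Manager(tp_size=2, cp_size=1, backend="nccl", world_size=8)
model = NeMoAutoModelForCausalLM.from_config(config, distributed=manager)
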
nemo_automodel/recipes/llm/train_ft.py

Lines changed: 28 additions & 2 deletions
@@ -201,6 +201,32 @@ def build_model_and_optimizer(
 
         kwargs["quantization_config"] = create_bnb_config(cfg_quantization)
 
+    # Optionally pass internal parallel scheme for HF models (non-pipeline) to let the model self-parallelize
+    used_internal_parallel = False
+    if (
+        is_hf_model
+        and autopipeline is None
+        and get_world_size_safe() > 1
+        and not isinstance(model_wrapper, MegatronFSDPManager)
+    ):
+        internal_scheme = {
+            "tp_size": int(tp_size) if tp_size is not None else 1,
+            "cp_size": int(cp_size) if cp_size is not None else 1,
+            "pp_size": 1,
+        }
+        try:
+            world_size = get_world_size_safe()
+            denom = max(1, internal_scheme["tp_size"] * internal_scheme["cp_size"])
+            if world_size > 1 and world_size % denom == 0:
+                internal_scheme["dp_size"] = world_size // denom
+        except Exception:
+            pass
+        internal_scheme["backend"] = "nccl" if torch.cuda.is_available() else "gloo"
+        kwargs["distributed"] = internal_scheme
+        if getattr(model_wrapper, "device_mesh", None) is not None:
+            kwargs["device_mesh"] = model_wrapper.device_mesh
+        used_internal_parallel = True
+
     # Instantiate the model in meta device to avoid OOM
     with init_ctx:
         if is_hf_model and (tp_size > 1 or cp_size > 1):
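
The dp_size arithmetic above divides the world size by the tp*cp product when it divides evenly. For example, on an illustrative 8-GPU run with tp_size=2 and cp_size=1, the recipe would hand the model a scheme like this:

world_size = 8
internal_scheme = {"tp_size": 2, "cp_size": 1, "pp_size": 1}
denom = max(1, internal_scheme["tp_size"] * internal_scheme["cp_size"])  # 2
if world_size > 1 and world_size % denom == 0:
    internal_scheme["dp_size"] = world_size // denom                     # 4
internal_scheme["backend"] = "nccl"  # assuming CUDA is available
# -> {"tp_size": 2, "cp_size": 1, "pp_size": 1, "dp_size": 4, "backend": "nccl"}
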
@@ -263,7 +289,7 @@ def build_model_and_optimizer(
             model = autopipeline
         else:
             load_weights = False
-            if parallelize_fn is not None and get_world_size_safe() > 1:
+            if not used_internal_parallel and parallelize_fn is not None and get_world_size_safe() > 1:
                 parallelize_fn(
                     model,
                     world_mesh=model_wrapper.device_mesh,
@@ -281,7 +307,7 @@ def build_model_and_optimizer(
                     ep_shard_axis_names=("ep_shard",),
                 )
                 load_weights = True
-            elif callable(getattr(model_wrapper, "parallelize", None)):
+            elif not used_internal_parallel and callable(getattr(model_wrapper, "parallelize", None)):
                 # FSDP2 and MegatronFSDP should already be on the correct device
                 if isinstance(model_wrapper, MegatronFSDPManager):
                     # MegatronFSDP instantiate optimizer inside parallelize_function
