Commit 56564c8

committed: wip
Signed-off-by: Alexandros Koumparoulis <[email protected]>
1 parent 87d06f8 · commit 56564c8

File tree

2 files changed: +201 -3 lines changed


nemo_automodel/_transformers/auto_model.py

Lines changed: 173 additions & 1 deletion
```diff
@@ -17,7 +17,7 @@
 import inspect
 import logging
 import types
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 from torch.nn.attention import SDPBackend, sdpa_kernel
@@ -35,6 +35,9 @@
 from nemo_automodel._transformers.registry import ModelRegistry
 from nemo_automodel.shared.import_utils import safe_import
 from nemo_automodel.shared.utils import dtype_from_str
+from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager
+from nemo_automodel.components.distributed.ddp import DDPManager
+from torch.distributed.device_mesh import DeviceMesh
 
 HAS_LIGER_KERNEL, liger_kernel_trf = safe_import("liger_kernel.transformers")
 logger = logging.getLogger(__name__)
```
```diff
@@ -148,6 +151,29 @@ def _get_next_fallback_attn(attn_implementation: str) -> str:
     return priorities[0]
 
 
+def _choose_device(device: Optional[torch.device]) -> torch.device:
+    if device is not None:
+        return device
+    if torch.cuda.is_available():
+        import os
+
+        local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        return torch.device("cuda", local_rank)
+    return torch.device("cpu")
+
+
+def _move_module_to_device(module: torch.nn.Module, device: torch.device, torch_dtype: Any) -> None:
+    # torch_dtype can be "auto", torch.dtype, or string
+    if torch_dtype == "auto":
+        dtype = None
+    else:
+        dtype = dtype_from_str(torch_dtype) if isinstance(torch_dtype, str) else torch_dtype
+    if dtype is not None:
+        module.to(device=device, dtype=dtype)
+    else:
+        module.to(device=device)
+
+
 class _BaseNeMoAutoModelClass(_BaseAutoModelClass):
     """
     Drop-in replacement for ``_BaseAutoModelClass`` that includes custom-kernels.
```
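
For context, a hedged, standalone sketch of what the two helpers above do: pick `cuda:LOCAL_RANK` when CUDA is available (falling back to CPU), and treat `torch_dtype="auto"` as "move only, keep the existing dtype". The helper names below and the simplified string-to-dtype handling (plain `getattr(torch, ...)` instead of `dtype_from_str`) are illustrative, not part of the commit.

```python
# Illustrative sketch only; mirrors _choose_device/_move_module_to_device above
# but is self-contained so it runs without nemo_automodel installed.
import os
from typing import Optional, Union

import torch


def choose_device(device: Optional[torch.device] = None) -> torch.device:
    if device is not None:
        return device
    if torch.cuda.is_available():
        # torchrun sets LOCAL_RANK; default to GPU 0 otherwise.
        return torch.device("cuda", int(os.environ.get("LOCAL_RANK", 0)))
    return torch.device("cpu")


def move_module(module: torch.nn.Module, device: torch.device, torch_dtype: Union[str, torch.dtype]) -> None:
    if torch_dtype == "auto":
        module.to(device=device)  # move only, keep existing parameter dtypes
    else:
        dtype = getattr(torch, torch_dtype) if isinstance(torch_dtype, str) else torch_dtype
        module.to(device=device, dtype=dtype)


if __name__ == "__main__":
    m = torch.nn.Linear(4, 4)
    move_module(m, choose_device(), "auto")
    move_module(m, choose_device(), torch.bfloat16)
    print(next(m.parameters()).device, next(m.parameters()).dtype)
```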
```diff
@@ -181,6 +207,10 @@ def from_pretrained(
         torch_dtype="auto",
         attn_implementation: str = "flash_attention_2",
         quantization_config=None,
+        device_mesh: Optional[DeviceMesh] = None,
+        distributed: Optional[Dict[str, Any]] = None,
+        device: Optional[torch.device] = None,
+        move_to_device: bool = False,
         **kwargs,
     ) -> PreTrainedModel:
         """
```
```diff
@@ -245,6 +275,10 @@ def _retry(**override):
                 use_liger_kernel=override.get("use_liger_kernel", use_liger_kernel),
                 use_sdpa_patching=override.get("use_sdpa_patching", use_sdpa_patching),
                 sdpa_method=sdpa_method,
+                device_mesh=device_mesh,
+                distributed=distributed,
+                device=device,
+                move_to_device=move_to_device,
                 **kwargs,
             )
 
@@ -285,6 +319,47 @@ def _retry(**override):
                 return _retry(attn_implementation=attn_implementation)
             raise e
 
+        # Decide kernel patching based on requested parallelism
+        if distributed is not None:
+            conf = distributed if isinstance(distributed, dict) else {}
+            tp = int(conf.get("tp_size", 1))
+            cp = int(conf.get("cp_size", 1))
+            if tp > 1 or cp > 1:
+                use_liger_kernel = False
+
+        # Build internal wrapper (FSDP2 or DDP) using device_mesh or config
+        internal_wrapper = None
+        if torch.distributed.is_initialized():
+            conf = distributed if isinstance(distributed, dict) else {}
+            tp = int(conf.get("tp_size", 1))
+            cp = int(conf.get("cp_size", 1))
+            pp = int(conf.get("pp_size", 1))
+            dp = int(conf.get("dp_size", 0))
+            world_size = conf.get("world_size", torch.distributed.get_world_size())
+            backend = conf.get("backend", "nccl" if torch.cuda.is_available() else "gloo")
+
+            if tp > 1 or cp > 1 or pp > 1:
+                internal_wrapper = FSDP2Manager(
+                    dp_size=dp if dp > 0 else None,
+                    tp_size=tp,
+                    cp_size=cp,
+                    pp_size=pp,
+                    backend=backend,
+                    world_size=world_size,
+                    use_hf_tp_plan=bool(conf.get("use_hf_tp_plan", False)),
+                    sequence_parallel=bool(conf.get("sequence_parallel", False)),
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )
+                # If caller provided a prebuilt DeviceMesh, inject it
+                if device_mesh is not None:
+                    internal_wrapper.device_mesh = device_mesh
+            elif conf:
+                internal_wrapper = DDPManager(
+                    backend=backend,
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )
+            # no-op: model_wrapper removed; liger disabled via conf above
+
         # Kernel patching
         try:
             if use_liger_kernel:
```
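
The wrapper selection above boils down to a small decision table: if the process group is not initialized nothing is wrapped; otherwise FSDP2 is used whenever tensor, context, or pipeline parallelism is requested, DDP when only a data-parallel style config is given, and nothing when `distributed` is absent. A standalone sketch of that rule (function name and return strings are illustrative, not part of the commit):

```python
# Illustrative decision rule mirroring the branch structure above.
from typing import Any, Dict, Optional


def select_wrapper(conf: Optional[Dict[str, Any]], dist_initialized: bool) -> Optional[str]:
    if not dist_initialized:
        return None  # parallelization is skipped entirely
    conf = conf or {}
    tp = int(conf.get("tp_size", 1))
    cp = int(conf.get("cp_size", 1))
    pp = int(conf.get("pp_size", 1))
    if tp > 1 or cp > 1 or pp > 1:
        return "fsdp2"
    return "ddp" if conf else None


assert select_wrapper({"tp_size": 2}, dist_initialized=True) == "fsdp2"
assert select_wrapper({"dp_size": 8}, dist_initialized=True) == "ddp"
assert select_wrapper(None, dist_initialized=True) is None
assert select_wrapper({"tp_size": 4}, dist_initialized=False) is None
```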
```diff
@@ -303,6 +378,22 @@ def _retry(**override):
             logging.warning("Retrying without SDPA patching.")
             return _retry(use_sdpa_patching=False)
 
+        # Optional: parallelize model according to parallel_scheme decision
+        # Priority: FSDP2 when tp_size>1 or cp_size>1; else DDP when dp_size>1; else single GPU
+        if internal_wrapper is not None:
+            if torch.distributed.is_initialized():
+                try:
+                    model = internal_wrapper.parallelize(model)
+                except Exception as e:
+                    logger.warning("internal parallelize failed: %s", e)
+            else:
+                logger.warning("distributed requested but torch.distributed is not initialized; skipping parallelize")
+
+        # Finally move the model to the chosen device (post-parallelization), if requested
+        if move_to_device:
+            dev = _choose_device(device)
+            _move_module_to_device(model, dev, torch_dtype)
+
         model.config.update({"nemo_version": __version__})
         return model
 
```
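Note that both the `parallelize` call and the fallback warning depend on the default process group already existing. Below is a minimal sketch of the setup a caller (e.g. under `torchrun`) would perform before invoking `from_pretrained(..., distributed=...)`; this is standard `torch.distributed` usage, not code from this commit.

```python
# Standard process-group bootstrap expected before the parallelize path runs.
import os

import torch
import torch.distributed as dist


def init_distributed() -> int:
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl")  # MASTER_ADDR/PORT come from the torchrun env
    else:
        dist.init_process_group(backend="gloo")
    return local_rank


if __name__ == "__main__":
    init_distributed()
    # ... call from_pretrained(..., distributed={...}, move_to_device=True) here ...
    dist.destroy_process_group()
```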

```diff
@@ -317,6 +408,10 @@ def from_config(
         torch_dtype: Union[str, torch.dtype] = "auto",
         attn_implementation: str = "flash_attention_2",
         quantization_config=None,
+        device_mesh: Optional[DeviceMesh] = None,
+        distributed: Optional[Dict[str, Any]] = None,
+        device: Optional[torch.device] = None,
+        move_to_device: bool = False,
         **kwargs,
     ) -> PreTrainedModel:
         """
```
```diff
@@ -373,6 +468,10 @@ def _retry(**override):
                 use_liger_kernel=override.get("use_liger_kernel", use_liger_kernel),
                 use_sdpa_patching=override.get("use_sdpa_patching", use_sdpa_patching),
                 sdpa_method=sdpa_method,
+                device_mesh=device_mesh,
+                distributed=distributed,
+                device=device,
+                move_to_device=move_to_device,
                 **kwargs,
             )
 
@@ -407,6 +506,31 @@ def _retry(**override):
                 return _retry(attn_implementation="eager")
             raise e
 
+        # Initialize model wrapper from `distributed` if provided
+        if distributed is not None and model_wrapper is None:
+            if isinstance(distributed, FSDP2Manager):
+                model_wrapper = distributed
+            elif isinstance(distributed, dict):
+                init_kwargs = dict(distributed)
+                try:
+                    import torch.distributed as dist
+
+                    if "world_size" not in init_kwargs and dist.is_available() and dist.is_initialized():
+                        init_kwargs["world_size"] = dist.get_world_size()
+                except Exception:
+                    pass
+                if "backend" not in init_kwargs:
+                    init_kwargs["backend"] = "nccl" if torch.cuda.is_available() else "gloo"
+                model_wrapper = FSDP2Manager(**init_kwargs)
+
+        # If distributed tensor/context parallelism is requested, avoid Liger patch like in train_ft
+        if model_wrapper is not None:
+            try:
+                if getattr(model_wrapper, "tp_size", 1) > 1 or getattr(model_wrapper, "cp_size", 1) > 1:
+                    use_liger_kernel = False
+            except Exception:
+                pass
+
         # Kernel patching
         try:
             if use_liger_kernel:
```
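
The block above accepts `distributed` either as a ready-made `FSDP2Manager` or as a plain dict of constructor kwargs, defaulting `world_size` from the live process group and `backend` from CUDA availability before instantiating the manager. A standalone sketch of that normalization (the helper name and the generic `manager_cls` parameter are illustrative):

```python
# Illustrative normalization of the `distributed` argument accepted above.
from typing import Any, Dict, Union

import torch
import torch.distributed as dist


def normalize_distributed(distributed: Union[Any, Dict[str, Any]], manager_cls: type) -> Any:
    if isinstance(distributed, manager_cls):
        return distributed  # already a manager instance, use as-is
    init_kwargs = dict(distributed)  # otherwise assume a dict of constructor kwargs
    if "world_size" not in init_kwargs and dist.is_available() and dist.is_initialized():
        init_kwargs["world_size"] = dist.get_world_size()
    init_kwargs.setdefault("backend", "nccl" if torch.cuda.is_available() else "gloo")
    return manager_cls(**init_kwargs)
```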
```diff
@@ -425,6 +549,54 @@ def _retry(**override):
             logging.warning("Retrying without SDPA patching.")
             return _retry(use_sdpa_patching=False)
 
+        # Optional: parallelize model according to parallel_scheme decision
+        # Priority: FSDP2 when tp_size>1 or cp_size>1; else DDP when dp_size>1; else single GPU
+        if model_wrapper is not None:
+            if torch.distributed.is_initialized():
+                try:
+                    model = model_wrapper.parallelize(model)
+                except Exception as e:
+                    logger.warning("model_wrapper.parallelize failed: %s", e)
+            else:
+                logger.warning("model_wrapper provided but torch.distributed is not initialized; skipping parallelize")
+        elif distributed is not None and torch.distributed.is_initialized():
+            conf = distributed if isinstance(distributed, dict) else {}
+            tp = int(conf.get("tp_size", 1))
+            cp = int(conf.get("cp_size", 1))
+            pp = int(conf.get("pp_size", 1))
+            dp = int(conf.get("dp_size", 0))
+            world_size = conf.get("world_size", torch.distributed.get_world_size())
+            backend = conf.get("backend", "nccl" if torch.cuda.is_available() else "gloo")
+
+            if tp > 1 or cp > 1 or pp > 1:
+                manager = FSDP2Manager(
+                    dp_size=dp if dp > 0 else None,
+                    tp_size=tp,
+                    cp_size=cp,
+                    pp_size=pp,
+                    backend=backend,
+                    world_size=world_size,
+                    use_hf_tp_plan=bool(conf.get("use_hf_tp_plan", False)),
+                    sequence_parallel=bool(conf.get("sequence_parallel", False)),
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )
+                try:
+                    model = manager.parallelize(model)
+                except Exception as e:
+                    logger.warning("FSDP2Manager.parallelize failed: %s", e)
+            else:
+                ddp_backend = backend
+                try:
+                    ddp_manager = DDPManager(backend=ddp_backend, activation_checkpointing=bool(conf.get("activation_checkpointing", False)))
+                    model = ddp_manager.parallelize(model)
+                except Exception as e:
+                    logger.warning("DDPManager.parallelize failed: %s", e)
+
+        # Finally move the model to the chosen device (post-parallelization), if requested
+        if move_to_device:
+            dev = _choose_device(device)
+            _move_module_to_device(model, dev, torch_dtype)
+
         model.config.update({"nemo_version": __version__})
         return model
 
```

nemo_automodel/recipes/llm/train_ft.py

Lines changed: 28 additions & 2 deletions
```diff
@@ -201,6 +201,32 @@ def build_model_and_optimizer(
 
         kwargs["quantization_config"] = create_bnb_config(cfg_quantization)
 
+    # Optionally pass internal parallel scheme for HF models (non-pipeline) to let the model self-parallelize
+    used_internal_parallel = False
+    if (
+        is_hf_model
+        and autopipeline is None
+        and get_world_size_safe() > 1
+        and not isinstance(model_wrapper, MegatronFSDPManager)
+    ):
+        internal_scheme = {
+            "tp_size": int(tp_size) if tp_size is not None else 1,
+            "cp_size": int(cp_size) if cp_size is not None else 1,
+            "pp_size": 1,
+        }
+        try:
+            world_size = get_world_size_safe()
+            denom = max(1, internal_scheme["tp_size"] * internal_scheme["cp_size"])
+            if world_size > 1 and world_size % denom == 0:
+                internal_scheme["dp_size"] = world_size // denom
+        except Exception:
+            pass
+        internal_scheme["backend"] = "nccl" if torch.cuda.is_available() else "gloo"
+        kwargs["distributed"] = internal_scheme
+        if getattr(model_wrapper, "device_mesh", None) is not None:
+            kwargs["device_mesh"] = model_wrapper.device_mesh
+        used_internal_parallel = True
+
     # Instantiate the model in meta device to avoid OOM
     with init_ctx:
         if is_hf_model and (tp_size > 1 or cp_size > 1):
```
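
The `dp_size` derivation here is plain integer arithmetic: with 8 ranks, `tp_size=2`, and `cp_size=1`, the denominator is 2 and `dp_size` becomes 4; when `world_size` is not divisible by `tp_size * cp_size`, `dp_size` is simply left unset. A small sketch of that computation (the function name is illustrative):

```python
# Sketch of the dp_size derivation used when building `internal_scheme`.
from typing import Any, Dict


def build_internal_scheme(world_size: int, tp_size: int = 1, cp_size: int = 1) -> Dict[str, Any]:
    scheme: Dict[str, Any] = {"tp_size": tp_size, "cp_size": cp_size, "pp_size": 1}
    denom = max(1, tp_size * cp_size)
    if world_size > 1 and world_size % denom == 0:
        scheme["dp_size"] = world_size // denom
    return scheme


assert build_internal_scheme(8, tp_size=2)["dp_size"] == 4
assert "dp_size" not in build_internal_scheme(6, tp_size=4)  # 6 % 4 != 0, leave unset
```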
```diff
@@ -263,7 +289,7 @@
             model = autopipeline
         else:
             load_weights = False
-            if parallelize_fn is not None and get_world_size_safe() > 1:
+            if not (is_hf_model and autopipeline is None and get_world_size_safe() > 1 and 'distributed' in kwargs) and parallelize_fn is not None and get_world_size_safe() > 1:
                 parallelize_fn(
                     model,
                     world_mesh=model_wrapper.device_mesh,
@@ -281,7 +307,7 @@
                     ep_shard_axis_names=("ep_shard",),
                 )
                 load_weights = True
-            elif callable(getattr(model_wrapper, "parallelize", None)):
+            elif not (is_hf_model and autopipeline is None and get_world_size_safe() > 1 and 'distributed' in kwargs) and callable(getattr(model_wrapper, "parallelize", None)):
                 # FSDP2 and MegatronFSDP should already be on the correct device
                 if isinstance(model_wrapper, MegatronFSDPManager):
                     # MegatronFSDP instantiate optimizer inside parallelize_function
```
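
The same four-part condition now appears in both branch guards above. As a readability sketch only (not part of the commit), it could be captured in a single predicate whose meaning is "the HF model was asked to parallelize itself inside from_pretrained/from_config, so recipe-level parallelization must be skipped":

```python
# Refactoring sketch only: the guard repeated in the two branches above.
from typing import Any, Dict


def uses_internal_parallelism(is_hf_model: bool, autopipeline: Any, world_size: int, kwargs: Dict[str, Any]) -> bool:
    return is_hf_model and autopipeline is None and world_size > 1 and "distributed" in kwargs
```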
