@@ -17,7 +17,7 @@
import inspect
import logging
import types
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
@@ -35,6 +35,9 @@
from nemo_automodel._transformers.registry import ModelRegistry
from nemo_automodel.shared.import_utils import safe_import
from nemo_automodel.shared.utils import dtype_from_str
+from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager
+from nemo_automodel.components.distributed.ddp import DDPManager
+from torch.distributed.device_mesh import DeviceMesh

HAS_LIGER_KERNEL, liger_kernel_trf = safe_import("liger_kernel.transformers")
logger = logging.getLogger(__name__)
@@ -148,6 +151,29 @@ def _get_next_fallback_attn(attn_implementation: str) -> str:
    return priorities[0]


+def _choose_device(device: Optional[torch.device]) -> torch.device:
+    if device is not None:
+        return device
+    if torch.cuda.is_available():
+        import os
+
+        local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        return torch.device("cuda", local_rank)
+    return torch.device("cpu")
+
+
+def _move_module_to_device(module: torch.nn.Module, device: torch.device, torch_dtype: Any) -> None:
+    # torch_dtype can be "auto", a torch.dtype, or a dtype string
+    if torch_dtype == "auto":
+        dtype = None
+    else:
+        dtype = dtype_from_str(torch_dtype) if isinstance(torch_dtype, str) else torch_dtype
+    if dtype is not None:
+        module.to(device=device, dtype=dtype)
+    else:
+        module.to(device=device)
+
+
class _BaseNeMoAutoModelClass(_BaseAutoModelClass):
    """
    Drop-in replacement for ``_BaseAutoModelClass`` that includes custom-kernels.
@@ -181,6 +207,10 @@ def from_pretrained(
        torch_dtype="auto",
        attn_implementation: str = "flash_attention_2",
        quantization_config=None,
+        device_mesh: Optional[DeviceMesh] = None,
+        distributed: Optional[Dict[str, Any]] = None,
+        device: Optional[torch.device] = None,
+        move_to_device: bool = False,
        **kwargs,
    ) -> PreTrainedModel:
        """
@@ -245,6 +275,10 @@ def _retry(**override):
                use_liger_kernel=override.get("use_liger_kernel", use_liger_kernel),
                use_sdpa_patching=override.get("use_sdpa_patching", use_sdpa_patching),
                sdpa_method=sdpa_method,
+                device_mesh=device_mesh,
+                distributed=distributed,
+                device=device,
+                move_to_device=move_to_device,
                **kwargs,
            )

@@ -285,6 +319,47 @@ def _retry(**override):
                return _retry(attn_implementation=attn_implementation)
            raise e

+        # Decide kernel patching based on requested parallelism
+        if distributed is not None:
+            conf = distributed if isinstance(distributed, dict) else {}
+            tp = int(conf.get("tp_size", 1))
+            cp = int(conf.get("cp_size", 1))
+            if tp > 1 or cp > 1:
+                use_liger_kernel = False
+
+        # Build internal wrapper (FSDP2 or DDP) using device_mesh or config
+        internal_wrapper = None
+        if torch.distributed.is_initialized():
+            conf = distributed if isinstance(distributed, dict) else {}
+            tp = int(conf.get("tp_size", 1))
+            cp = int(conf.get("cp_size", 1))
+            pp = int(conf.get("pp_size", 1))
+            dp = int(conf.get("dp_size", 0))
+            world_size = conf.get("world_size", torch.distributed.get_world_size())
+            backend = conf.get("backend", "nccl" if torch.cuda.is_available() else "gloo")
+
+            if tp > 1 or cp > 1 or pp > 1:
+                internal_wrapper = FSDP2Manager(
+                    dp_size=dp if dp > 0 else None,
+                    tp_size=tp,
+                    cp_size=cp,
+                    pp_size=pp,
+                    backend=backend,
+                    world_size=world_size,
+                    use_hf_tp_plan=bool(conf.get("use_hf_tp_plan", False)),
+                    sequence_parallel=bool(conf.get("sequence_parallel", False)),
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )
+                # If caller provided a prebuilt DeviceMesh, inject it
+                if device_mesh is not None:
+                    internal_wrapper.device_mesh = device_mesh
+            elif conf:
+                internal_wrapper = DDPManager(
+                    backend=backend,
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )
+            # no-op: model_wrapper removed; liger disabled via conf above
+
        # Kernel patching
        try:
            if use_liger_kernel:
@@ -303,6 +378,22 @@ def _retry(**override):
            logging.warning("Retrying without SDPA patching.")
            return _retry(use_sdpa_patching=False)

+        # Optional: parallelize the model according to the requested parallelism.
+        # Priority: FSDP2 when tp_size > 1 or cp_size > 1; else DDP when dp_size > 1; else single GPU.
+        if internal_wrapper is not None:
+            if torch.distributed.is_initialized():
+                try:
+                    model = internal_wrapper.parallelize(model)
+                except Exception as e:
+                    logger.warning("internal parallelize failed: %s", e)
+            else:
+                logger.warning("distributed requested but torch.distributed is not initialized; skipping parallelize")
+
+        # Finally, move the model to the chosen device (post-parallelization), if requested
+        if move_to_device:
+            dev = _choose_device(device)
+            _move_module_to_device(model, dev, torch_dtype)
+
        model.config.update({"nemo_version": __version__})
        return model

@@ -317,6 +408,10 @@ def from_config(
        torch_dtype: Union[str, torch.dtype] = "auto",
        attn_implementation: str = "flash_attention_2",
        quantization_config=None,
+        device_mesh: Optional[DeviceMesh] = None,
+        distributed: Optional[Dict[str, Any]] = None,
+        device: Optional[torch.device] = None,
+        move_to_device: bool = False,
        **kwargs,
    ) -> PreTrainedModel:
        """
@@ -373,6 +468,10 @@ def _retry(**override):
                use_liger_kernel=override.get("use_liger_kernel", use_liger_kernel),
                use_sdpa_patching=override.get("use_sdpa_patching", use_sdpa_patching),
                sdpa_method=sdpa_method,
+                device_mesh=device_mesh,
+                distributed=distributed,
+                device=device,
+                move_to_device=move_to_device,
                **kwargs,
            )

@@ -407,6 +506,31 @@ def _retry(**override):
                return _retry(attn_implementation="eager")
            raise e

+        # Initialize model wrapper from `distributed` if provided
+        if distributed is not None and model_wrapper is None:
+            if isinstance(distributed, FSDP2Manager):
+                model_wrapper = distributed
+            elif isinstance(distributed, dict):
+                init_kwargs = dict(distributed)
+                try:
+                    import torch.distributed as dist
+
+                    if "world_size" not in init_kwargs and dist.is_available() and dist.is_initialized():
+                        init_kwargs["world_size"] = dist.get_world_size()
+                except Exception:
+                    pass
+                if "backend" not in init_kwargs:
+                    init_kwargs["backend"] = "nccl" if torch.cuda.is_available() else "gloo"
+                model_wrapper = FSDP2Manager(**init_kwargs)
+
+        # If distributed tensor/context parallelism is requested, avoid the Liger patch like in train_ft
+        if model_wrapper is not None:
+            try:
+                if getattr(model_wrapper, "tp_size", 1) > 1 or getattr(model_wrapper, "cp_size", 1) > 1:
+                    use_liger_kernel = False
+            except Exception:
+                pass
+
        # Kernel patching
        try:
            if use_liger_kernel:
@@ -425,6 +549,54 @@ def _retry(**override):
            logging.warning("Retrying without SDPA patching.")
            return _retry(use_sdpa_patching=False)

+        # Optional: parallelize the model according to the requested parallelism.
+        # Priority: FSDP2 when tp_size > 1 or cp_size > 1; else DDP when dp_size > 1; else single GPU.
+        if model_wrapper is not None:
+            if torch.distributed.is_initialized():
+                try:
+                    model = model_wrapper.parallelize(model)
+                except Exception as e:
+                    logger.warning("model_wrapper.parallelize failed: %s", e)
+            else:
+                logger.warning("model_wrapper provided but torch.distributed is not initialized; skipping parallelize")
+        elif distributed is not None and torch.distributed.is_initialized():
+            conf = distributed if isinstance(distributed, dict) else {}
+            tp = int(conf.get("tp_size", 1))
+            cp = int(conf.get("cp_size", 1))
+            pp = int(conf.get("pp_size", 1))
+            dp = int(conf.get("dp_size", 0))
+            world_size = conf.get("world_size", torch.distributed.get_world_size())
+            backend = conf.get("backend", "nccl" if torch.cuda.is_available() else "gloo")
+
+            if tp > 1 or cp > 1 or pp > 1:
+                manager = FSDP2Manager(
+                    dp_size=dp if dp > 0 else None,
+                    tp_size=tp,
+                    cp_size=cp,
+                    pp_size=pp,
+                    backend=backend,
+                    world_size=world_size,
+                    use_hf_tp_plan=bool(conf.get("use_hf_tp_plan", False)),
+                    sequence_parallel=bool(conf.get("sequence_parallel", False)),
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )
+                try:
+                    model = manager.parallelize(model)
+                except Exception as e:
+                    logger.warning("FSDP2Manager.parallelize failed: %s", e)
+            else:
+                ddp_backend = backend
+                try:
+                    ddp_manager = DDPManager(backend=ddp_backend, activation_checkpointing=bool(conf.get("activation_checkpointing", False)))
+                    model = ddp_manager.parallelize(model)
+                except Exception as e:
+                    logger.warning("DDPManager.parallelize failed: %s", e)
+
+        # Finally, move the model to the chosen device (post-parallelization), if requested
+        if move_to_device:
+            dev = _choose_device(device)
+            _move_module_to_device(model, dev, torch_dtype)
+
        model.config.update({"nemo_version": __version__})
        return model

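
Usage sketch for the new arguments (illustrative only, not part of this diff: the `NeMoAutoModelForCausalLM` entry point, the exact `distributed` keys, and a `torchrun` launch are assumptions):

    # torchrun --nproc-per-node=8 train.py
    import torch
    from nemo_automodel import NeMoAutoModelForCausalLM  # assumed public entry point

    # The diff only parallelizes when the default process group already exists,
    # so initialize it first (torchrun provides RANK/WORLD_SIZE/LOCAL_RANK).
    torch.distributed.init_process_group(backend="nccl")

    model = NeMoAutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.1-8B",
        torch_dtype=torch.bfloat16,
        # tp_size > 1 selects the FSDP2Manager path and disables the Liger patch
        distributed={"tp_size": 2, "dp_size": 4, "activation_checkpointing": True},
        # after parallelization, place the model on cuda:LOCAL_RANK (or CPU if CUDA is unavailable)
        move_to_device=True,
    )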