 import inspect
 import logging
 import types
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import torch
 from torch.nn.attention import SDPBackend, sdpa_kernel
 from nemo_automodel._transformers.registry import ModelRegistry
 from nemo_automodel.shared.import_utils import safe_import
 from nemo_automodel.shared.utils import dtype_from_str
+from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager
+from nemo_automodel.components.distributed.ddp import DDPManager
+from torch.distributed.device_mesh import DeviceMesh

 HAS_LIGER_KERNEL, liger_kernel_trf = safe_import("liger_kernel.transformers")
 logger = logging.getLogger(__name__)
@@ -148,6 +151,29 @@ def _get_next_fallback_attn(attn_implementation: str) -> str:
     return priorities[0]


+def _choose_device(device: Optional[torch.device]) -> torch.device:
+    if device is not None:
+        return device
+    if torch.cuda.is_available():
+        import os
+
+        local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        return torch.device("cuda", local_rank)
+    return torch.device("cpu")
+
+
+def _move_module_to_device(module: torch.nn.Module, device: torch.device, torch_dtype: Any) -> None:
+    # torch_dtype can be "auto", torch.dtype, or string
+    if torch_dtype == "auto":
+        dtype = None
+    else:
+        dtype = dtype_from_str(torch_dtype) if isinstance(torch_dtype, str) else torch_dtype
+    if dtype is not None:
+        module.to(device=device, dtype=dtype)
+    else:
+        module.to(device=device)
+
+
 class _BaseNeMoAutoModelClass(_BaseAutoModelClass):
     """
     Drop-in replacement for ``_BaseAutoModelClass`` that includes custom kernels.
@@ -182,6 +208,10 @@ def from_pretrained(
         attn_implementation: str = "flash_attention_2",
         quantization_config=None,
         force_hf: bool = False,
+        device_mesh: Optional[DeviceMesh] = None,
+        distributed: Optional[Dict[str, Any]] = None,
+        device: Optional[torch.device] = None,
+        move_to_device: bool = False,
         **kwargs,
     ) -> PreTrainedModel:
         """
@@ -248,6 +278,10 @@ def _retry(**override):
                 use_liger_kernel=override.get("use_liger_kernel", use_liger_kernel),
                 use_sdpa_patching=override.get("use_sdpa_patching", use_sdpa_patching),
                 sdpa_method=sdpa_method,
+                device_mesh=device_mesh,
+                distributed=distributed,
+                device=device,
+                move_to_device=move_to_device,
                 **kwargs,
             )

@@ -291,6 +325,47 @@ def _retry(**override):
                 return _retry(attn_implementation=attn_implementation)
             raise e

+        # Decide kernel patching based on requested parallelism
+        if distributed is not None:
+            conf = distributed if isinstance(distributed, dict) else {}
+            tp = int(conf.get("tp_size", 1))
+            cp = int(conf.get("cp_size", 1))
+            if tp > 1 or cp > 1:
+                use_liger_kernel = False
+
+        # Build internal wrapper (FSDP2 or DDP) using device_mesh or config
+        internal_wrapper = None
+        if torch.distributed.is_initialized():
+            conf = distributed if isinstance(distributed, dict) else {}
+            tp = int(conf.get("tp_size", 1))
+            cp = int(conf.get("cp_size", 1))
+            pp = int(conf.get("pp_size", 1))
+            dp = int(conf.get("dp_size", 0))
+            world_size = conf.get("world_size", torch.distributed.get_world_size())
+            backend = conf.get("backend", "nccl" if torch.cuda.is_available() else "gloo")
+
+            if tp > 1 or cp > 1 or pp > 1:
+                internal_wrapper = FSDP2Manager(
+                    dp_size=dp if dp > 0 else None,
+                    tp_size=tp,
+                    cp_size=cp,
+                    pp_size=pp,
+                    backend=backend,
+                    world_size=world_size,
+                    use_hf_tp_plan=bool(conf.get("use_hf_tp_plan", False)),
+                    sequence_parallel=bool(conf.get("sequence_parallel", False)),
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )
+                # If caller provided a prebuilt DeviceMesh, inject it
+                if device_mesh is not None:
+                    internal_wrapper.device_mesh = device_mesh
+            elif conf:
+                internal_wrapper = DDPManager(
+                    backend=backend,
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )
+        # no-op: model_wrapper removed; liger disabled via conf above
+
         # Kernel patching
         try:
             if use_liger_kernel:
@@ -309,6 +384,22 @@ def _retry(**override):
             logging.warning("Retrying without SDPA patching.")
             return _retry(use_sdpa_patching=False)

+        # Optional: parallelize model according to parallel_scheme decision
+        # Priority: FSDP2 when tp_size>1 or cp_size>1; else DDP when dp_size>1; else single GPU
+        if internal_wrapper is not None:
+            if torch.distributed.is_initialized():
+                try:
+                    model = internal_wrapper.parallelize(model)
+                except Exception as e:
+                    logger.warning("internal parallelize failed: %s", e)
+            else:
+                logger.warning("distributed requested but torch.distributed is not initialized; skipping parallelize")
+
+        # Finally move the model to the chosen device (post-parallelization), if requested
+        if move_to_device:
+            dev = _choose_device(device)
+            _move_module_to_device(model, dev, torch_dtype)
+
         model.config.update({"nemo_version": __version__})
         return model

@@ -324,6 +415,10 @@ def from_config(
         attn_implementation: str = "flash_attention_2",
         quantization_config=None,
         force_hf: bool = False,
+        device_mesh: Optional[DeviceMesh] = None,
+        distributed: Optional[Dict[str, Any]] = None,
+        device: Optional[torch.device] = None,
+        move_to_device: bool = False,
         **kwargs,
     ) -> PreTrainedModel:
         """
@@ -382,6 +477,10 @@ def _retry(**override):
                 use_liger_kernel=override.get("use_liger_kernel", use_liger_kernel),
                 use_sdpa_patching=override.get("use_sdpa_patching", use_sdpa_patching),
                 sdpa_method=sdpa_method,
+                device_mesh=device_mesh,
+                distributed=distributed,
+                device=device,
+                move_to_device=move_to_device,
                 **kwargs,
             )

@@ -419,6 +518,31 @@ def _retry(**override):
                 return _retry(attn_implementation="eager")
             raise e

+        # Initialize model wrapper from `distributed` if provided
+        if distributed is not None and model_wrapper is None:
+            if isinstance(distributed, FSDP2Manager):
+                model_wrapper = distributed
+            elif isinstance(distributed, dict):
+                init_kwargs = dict(distributed)
+                try:
+                    import torch.distributed as dist
+
+                    if "world_size" not in init_kwargs and dist.is_available() and dist.is_initialized():
+                        init_kwargs["world_size"] = dist.get_world_size()
+                except Exception:
+                    pass
+                if "backend" not in init_kwargs:
+                    init_kwargs["backend"] = "nccl" if torch.cuda.is_available() else "gloo"
+                model_wrapper = FSDP2Manager(**init_kwargs)
+
+        # If distributed tensor/context parallelism is requested, avoid the Liger patch, as in train_ft
+        if model_wrapper is not None:
+            try:
+                if getattr(model_wrapper, "tp_size", 1) > 1 or getattr(model_wrapper, "cp_size", 1) > 1:
+                    use_liger_kernel = False
+            except Exception:
+                pass
+
         # Kernel patching
         try:
             if use_liger_kernel:
@@ -437,6 +561,54 @@ def _retry(**override):
             logging.warning("Retrying without SDPA patching.")
             return _retry(use_sdpa_patching=False)

+        # Optional: parallelize model according to parallel_scheme decision
+        # Priority: FSDP2 when tp_size>1 or cp_size>1; else DDP when dp_size>1; else single GPU
+        if model_wrapper is not None:
+            if torch.distributed.is_initialized():
+                try:
+                    model = model_wrapper.parallelize(model)
+                except Exception as e:
+                    logger.warning("model_wrapper.parallelize failed: %s", e)
+            else:
+                logger.warning("model_wrapper provided but torch.distributed is not initialized; skipping parallelize")
+        elif distributed is not None and torch.distributed.is_initialized():
+            conf = distributed if isinstance(distributed, dict) else {}
+            tp = int(conf.get("tp_size", 1))
+            cp = int(conf.get("cp_size", 1))
+            pp = int(conf.get("pp_size", 1))
+            dp = int(conf.get("dp_size", 0))
+            world_size = conf.get("world_size", torch.distributed.get_world_size())
+            backend = conf.get("backend", "nccl" if torch.cuda.is_available() else "gloo")
+
+            if tp > 1 or cp > 1 or pp > 1:
+                manager = FSDP2Manager(
+                    dp_size=dp if dp > 0 else None,
+                    tp_size=tp,
+                    cp_size=cp,
+                    pp_size=pp,
+                    backend=backend,
+                    world_size=world_size,
+                    use_hf_tp_plan=bool(conf.get("use_hf_tp_plan", False)),
+                    sequence_parallel=bool(conf.get("sequence_parallel", False)),
+                    activation_checkpointing=bool(conf.get("activation_checkpointing", False)),
+                )
+                try:
+                    model = manager.parallelize(model)
+                except Exception as e:
+                    logger.warning("FSDP2Manager.parallelize failed: %s", e)
+            else:
+                ddp_backend = backend
+                try:
+                    ddp_manager = DDPManager(backend=ddp_backend, activation_checkpointing=bool(conf.get("activation_checkpointing", False)))
+                    model = ddp_manager.parallelize(model)
+                except Exception as e:
+                    logger.warning("DDPManager.parallelize failed: %s", e)
+
+        # Finally move the model to the chosen device (post-parallelization), if requested
+        if move_to_device:
+            dev = _choose_device(device)
+            _move_module_to_device(model, dev, torch_dtype)
+
         model.config.update({"nemo_version": __version__})
         return model

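For orientation, here is a minimal usage sketch of the new keyword arguments introduced by this diff. It is an illustration under stated assumptions, not code from the commit: the concrete entry-point class (NeMoAutoModelForCausalLM), the checkpoint id, and the particular `distributed` keys are placeholders.

import torch
import torch.distributed as dist

from nemo_automodel import NeMoAutoModelForCausalLM  # assumed concrete subclass of _BaseNeMoAutoModelClass

# Hypothetical launch: torchrun --nproc-per-node=8 example.py
if not dist.is_initialized():
    # Same backend default the diff uses: "nccl" on GPU, "gloo" otherwise.
    dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")

model = NeMoAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # placeholder checkpoint id
    torch_dtype=torch.bfloat16,
    # tp_size/cp_size/pp_size > 1 routes through FSDP2Manager (and disables the Liger patch);
    # a non-empty dict with none of those set falls back to DDPManager.
    distributed={"tp_size": 2, "dp_size": 4},
    move_to_device=True,  # move the parallelized module to cuda:LOCAL_RANK (or CPU)
)

A prebuilt DeviceMesh can also be supplied via `device_mesh=`; in `from_pretrained` it is injected into the FSDP2 wrapper when tensor, context, or pipeline parallelism is requested.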