63 changes: 51 additions & 12 deletions dfm/src/automodel/_diffusers/auto_diffusion_pipeline.py
@@ -15,12 +15,13 @@
import copy
import logging
import os
from typing import Any, Dict, Iterable, Optional, Tuple
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import torch
import torch.nn as nn
from diffusers import DiffusionPipeline, WanPipeline
from nemo_automodel.components.distributed import parallelizer
from nemo_automodel.components.distributed.ddp import DDPManager
from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager
from nemo_automodel.shared.utils import dtype_from_str

@@ -29,6 +30,9 @@

logger = logging.getLogger(__name__)

# Type alias for parallel managers
ParallelManager = Union[FSDP2Manager, DDPManager]


def _init_parallelizer():
parallelizer.PARALLELIZATION_STRATEGIES["WanTransformer3DModel"] = WanParallelizationStrategy()
@@ -94,17 +98,52 @@ def _ensure_params_trainable(module: nn.Module, module_name: Optional[str] = None)
return num_trainable_parameters


def _create_parallel_manager(manager_args: Dict[str, Any]) -> ParallelManager:
"""
Factory function to create the appropriate parallel manager based on config.

The manager type is determined by the '_manager_type' key in manager_args:
- 'ddp': Creates a DDPManager for standard Distributed Data Parallel
- 'fsdp2' (default): Creates an FSDP2Manager for Fully Sharded Data Parallel

Args:
manager_args: Dictionary of arguments for the manager. Must include '_manager_type'
key to specify which manager to create. The '_manager_type' key is
removed before passing args to the manager constructor.

Returns:
Either an FSDP2Manager or DDPManager instance.

Raises:
ValueError: If an unknown manager type is specified.
"""
# Make a copy to avoid modifying the original dict
args = manager_args.copy()
manager_type = args.pop("_manager_type", "fsdp2").lower()

if manager_type == "ddp":
logger.info("[Parallel] Creating DDPManager with args: %s", args)
return DDPManager(**args)
elif manager_type == "fsdp2":
logger.info("[Parallel] Creating FSDP2Manager with args: %s", args)
return FSDP2Manager(**args)
else:
raise ValueError(f"Unknown manager type: '{manager_type}'. Expected 'ddp' or 'fsdp2'.")


class NeMoAutoDiffusionPipeline(DiffusionPipeline):
"""
Drop-in Diffusers pipeline that adds optional FSDP2/TP parallelization during from_pretrained.
Drop-in Diffusers pipeline that adds optional FSDP2/DDP parallelization during from_pretrained.

Features:
- Accepts a per-component mapping from component name to FSDP2Manager init args
- Accepts a per-component mapping from component name to parallel manager init args
- Moves all nn.Module components to the chosen device/dtype
- Parallelizes only components present in the mapping by constructing a manager per component
- Supports both FSDP2Manager and DDPManager via '_manager_type' key in config

parallel_scheme:
- Dict[str, Dict[str, Any]]: component name -> kwargs for FSDP2Manager(...)
- Dict[str, Dict[str, Any]]: component name -> kwargs for parallel manager
- Each component's kwargs should include '_manager_type': 'fsdp2' or 'ddp' (defaults to 'fsdp2')
"""

@classmethod
@@ -119,7 +158,7 @@ def from_pretrained(
load_for_training: bool = False,
components_to_load: Optional[Iterable[str]] = None,
**kwargs,
) -> tuple[DiffusionPipeline, Dict[str, FSDP2Manager]]:
) -> tuple[DiffusionPipeline, Dict[str, ParallelManager]]:
pipe: DiffusionPipeline = DiffusionPipeline.from_pretrained(
pretrained_model_name_or_path,
*model_args,
@@ -143,16 +182,16 @@ def from_pretrained(
logger.info("[INFO] Ensuring params trainable: %s", name)
_ensure_params_trainable(module, module_name=name)

# Use per-component FSDP2Manager init-args to parallelize components
created_managers: Dict[str, FSDP2Manager] = {}
# Use per-component manager init-args to parallelize components
created_managers: Dict[str, ParallelManager] = {}
if parallel_scheme is not None:
assert torch.distributed.is_initialized(), "Expect distributed environment to be initialized"
_init_parallelizer()
for comp_name, comp_module in _iter_pipeline_modules(pipe):
manager_args = parallel_scheme.get(comp_name)
if manager_args is None:
continue
manager = FSDP2Manager(**manager_args)
manager = _create_parallel_manager(manager_args)
created_managers[comp_name] = manager
parallel_module = manager.parallelize(comp_module)
setattr(pipe, comp_name, parallel_module)
@@ -177,7 +216,7 @@ def from_config(
device: Optional[torch.device] = None,
move_to_device: bool = True,
components_to_load: Optional[Iterable[str]] = None,
):
) -> tuple[WanPipeline, Dict[str, ParallelManager]]:
# Load just the config
from diffusers import WanTransformer3DModel

@@ -211,16 +250,16 @@ def from_config(
logger.info("[INFO] Moving module: %s to device/dtype", name)
_move_module_to_device(module, dev, torch_dtype)

# Use per-component FSDP2Manager init-args to parallelize components
created_managers: Dict[str, FSDP2Manager] = {}
# Use per-component manager init-args to parallelize components
created_managers: Dict[str, ParallelManager] = {}
if parallel_scheme is not None:
assert torch.distributed.is_initialized(), "Expect distributed environment to be initialized"
_init_parallelizer()
for comp_name, comp_module in _iter_pipeline_modules(pipe):
manager_args = parallel_scheme.get(comp_name)
if manager_args is None:
continue
manager = FSDP2Manager(**manager_args)
manager = _create_parallel_manager(manager_args)
created_managers[comp_name] = manager
parallel_module = manager.parallelize(comp_module)
setattr(pipe, comp_name, parallel_module)
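For context, a minimal usage sketch of the new per-component scheme follows; the import path, model id, and manager kwargs are illustrative assumptions, not part of this PR. Each entry maps a pipeline component name to the kwargs of its manager, with '_manager_type' selecting FSDP2 or DDP; note that from_pretrained asserts torch.distributed is already initialized when a parallel_scheme is passed.

# Hypothetical sketch -- import path, model id, and kwargs are assumptions.
import torch
import torch.distributed as dist

from dfm.automodel._diffusers.auto_diffusion_pipeline import NeMoAutoDiffusionPipeline  # import path assumed

dist.init_process_group(backend="nccl")  # assumes a torchrun-style launch; required before parallelization
world_size = dist.get_world_size()

parallel_scheme = {
    "transformer": {                      # only components named here get parallelized
        "_manager_type": "fsdp2",         # or "ddp" for plain DistributedDataParallel
        "dp_size": world_size,            # remaining kwargs are forwarded to FSDP2Manager(...)
        "tp_size": 1,
        "backend": "nccl",
        "world_size": world_size,
    },
}

pipe, managers = NeMoAutoDiffusionPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",   # model id assumed for illustration
    torch_dtype=torch.bfloat16,
    parallel_scheme=parallel_scheme,
    load_for_training=True,
)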
17 changes: 5 additions & 12 deletions dfm/src/automodel/flow_matching/flow_matching_pipeline.py
@@ -360,11 +360,11 @@ def step(
# ====================================================================
# Logging
# ====================================================================
if detailed_log or debug_mode:
if debug_mode and detailed_log:
self._log_detailed(
global_step, sampling_method, batch_size, sigma, timesteps, video_latents, noise, noisy_latents
)
elif summary_log:
elif debug_mode and summary_log:
logger.info(
f"[STEP {global_step}] σ=[{sigma.min():.3f},{sigma.max():.3f}] | "
f"t=[{timesteps.min():.1f},{timesteps.max():.1f}] | "
@@ -412,16 +412,9 @@ def step(
raise ValueError(f"Loss exploded: {average_weighted_loss.item()}")

# Logging
if detailed_log or debug_mode:
self._log_loss_detailed(
global_step,
model_pred,
target,
loss_weight,
average_unweighted_loss,
average_weighted_loss,
)
elif summary_log:
if debug_mode and detailed_log:
self._log_loss_detailed(global_step, model_pred, target, loss_weight, unweighted_loss, weighted_loss)
elif debug_mode and summary_log:
logger.info(
f"[STEP {global_step}] Loss: {average_weighted_loss.item():.6f} | "
f"w=[{loss_weight.min():.2f},{loss_weight.max():.2f}]"
140 changes: 109 additions & 31 deletions dfm/src/automodel/recipes/train.py
@@ -44,11 +44,37 @@ def build_model_and_optimizer(
device: torch.device,
dtype: torch.dtype,
cpu_offload: bool = False,
fsdp_cfg: Dict[str, Any] = {},
fsdp_cfg: Optional[Dict[str, Any]] = None,
ddp_cfg: Optional[Dict[str, Any]] = None,
attention_backend: Optional[str] = None,
optimizer_cfg: Optional[Dict[str, Any]] = None,
) -> tuple[NeMoWanPipeline, dict[str, Dict[str, Any]], torch.optim.Optimizer, Any]:
"""Build the diffusion model, parallel scheme, and optimizer."""
"""Build the diffusion model, parallel scheme, and optimizer.

Args:
model_id: Pretrained model name or path.
finetune_mode: Whether to load for finetuning.
learning_rate: Learning rate for optimizer.
device: Target device.
dtype: Model dtype.
cpu_offload: Whether to enable CPU offload (FSDP only).
fsdp_cfg: FSDP configuration dict. Mutually exclusive with ddp_cfg.
ddp_cfg: DDP configuration dict. Mutually exclusive with fsdp_cfg.
attention_backend: Optional attention backend override.
optimizer_cfg: Optional optimizer configuration.

Returns:
Tuple of (pipeline, parallel_scheme, optimizer, device_mesh or None).

Raises:
ValueError: If both fsdp_cfg and ddp_cfg are provided.
"""
# Validate mutually exclusive configs
if fsdp_cfg is not None and ddp_cfg is not None:
raise ValueError(
"Cannot specify both 'fsdp' and 'ddp' configurations. "
"Please provide only one distributed training strategy."
)

logging.info("[INFO] Building NeMoAutoDiffusionPipeline with transformer parallel scheme...")

@@ -57,26 +83,44 @@

world_size = dist.get_world_size() if dist.is_initialized() else 1

if fsdp_cfg.get("dp_size", None) is None:
denom = max(1, fsdp_cfg.get("tp_size", 1) * fsdp_cfg.get("cp_size", 1) * fsdp_cfg.get("pp_size", 1))
fsdp_cfg.dp_size = max(1, world_size // denom)

manager_args: Dict[str, Any] = {
"dp_size": fsdp_cfg.get("dp_size", None),
"dp_replicate_size": fsdp_cfg.get("dp_replicate_size", None),
"tp_size": fsdp_cfg.get("tp_size", 1),
"cp_size": fsdp_cfg.get("cp_size", 1),
"pp_size": fsdp_cfg.get("pp_size", 1),
"backend": "nccl",
"world_size": world_size,
"use_hf_tp_plan": fsdp_cfg.get("use_hf_tp_plan", False),
"activation_checkpointing": True,
"mp_policy": MixedPrecisionPolicy(
param_dtype=dtype,
reduce_dtype=torch.float32,
output_dtype=dtype,
),
}
# Build manager args based on which config is provided
if ddp_cfg is not None:
# DDP configuration
logging.info("[INFO] Using DDP (DistributedDataParallel) for training")
manager_args: Dict[str, Any] = {
"_manager_type": "ddp",
"backend": ddp_cfg.get("backend", "nccl"),
"world_size": world_size,
"activation_checkpointing": ddp_cfg.get("activation_checkpointing", False),
}
else:
# FSDP configuration (default)
fsdp_cfg = fsdp_cfg or {}
logging.info("[INFO] Using FSDP2 (Fully Sharded Data Parallel) for training")

dp_size = fsdp_cfg.get("dp_size", None)

if dp_size is None:
denom = max(1, fsdp_cfg.get("tp_size", 1) * fsdp_cfg.get("cp_size", 1) * fsdp_cfg.get("pp_size", 1))
dp_size = max(1, world_size // denom)

manager_args: Dict[str, Any] = {
"_manager_type": "fsdp2",
"dp_size": fsdp_cfg.get("dp_size", None),
"dp_replicate_size": fsdp_cfg.get("dp_replicate_size", None),
"tp_size": fsdp_cfg.get("tp_size", 1),
"cp_size": fsdp_cfg.get("cp_size", 1),
"pp_size": fsdp_cfg.get("pp_size", 1),
"backend": "nccl",
"world_size": world_size,
"use_hf_tp_plan": fsdp_cfg.get("use_hf_tp_plan", False),
"activation_checkpointing": fsdp_cfg.get("activation_checkpointing", True),
"mp_policy": MixedPrecisionPolicy(
param_dtype=dtype,
reduce_dtype=torch.float32,
output_dtype=dtype,
),
}

parallel_scheme = {"transformer": manager_args}

@@ -194,10 +238,19 @@ def setup(self):
logging.info(f"[INFO] Node rank: {self.node_rank}, Local rank: {self.local_rank}")
logging.info(f"[INFO] Learning rate: {self.learning_rate}")

fsdp_cfg = self.cfg.get("fsdp", {})
# Get distributed training configs (mutually exclusive)
fsdp_cfg = self.cfg.get("fsdp", None)
ddp_cfg = self.cfg.get("ddp", None)
fm_cfg = self.cfg.get("flow_matching", {})

self.cpu_offload = fsdp_cfg.get("cpu_offload", False)
# Validate mutually exclusive distributed configs
if fsdp_cfg is not None and ddp_cfg is not None:
raise ValueError(
"Cannot specify both 'fsdp' and 'ddp' configurations in YAML. "
"Please provide only one distributed training strategy."
)

self.cpu_offload = fsdp_cfg.get("cpu_offload", False) if fsdp_cfg else False

# Flow matching configuration
self.adapter_type = fm_cfg.get("adapter_type", "simple")
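To make the mutually exclusive options concrete, a hedged sketch of the two config dicts as they might look after YAML parsing is shown below; the key names are the ones read via .get(...) above, while the values are purely illustrative.

# Sketch only -- key names come from the .get(...) calls above; values are illustrative.
ddp_cfg = {
    "backend": "nccl",
    "activation_checkpointing": False,
}

fsdp_cfg = {
    "tp_size": 1,
    "cp_size": 1,
    "pp_size": 1,
    "activation_checkpointing": True,
    "cpu_offload": False,
    # dp_size is derived as world_size // (tp_size * cp_size * pp_size) when omitted
}

# Supplying both (or defining both 'fsdp' and 'ddp' sections in the YAML)
# raises ValueError -- pick exactly one distributed strategy.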
@@ -233,6 +286,7 @@ def setup(self):
dtype=self.bf16,
cpu_offload=self.cpu_offload,
fsdp_cfg=fsdp_cfg,
ddp_cfg=ddp_cfg,
optimizer_cfg=self.cfg.get("optim.optimizer", {}),
attention_backend=self.attention_backend,
)
@@ -288,13 +342,19 @@ def setup(self):
raise RuntimeError("Training dataloader is empty; cannot proceed with training")

# Derive DP size consistent with model parallel config
tp_size = fsdp_cfg.get("tp_size", 1)
cp_size = fsdp_cfg.get("cp_size", 1)
pp_size = fsdp_cfg.get("pp_size", 1)
denom = max(1, tp_size * cp_size * pp_size)
self.dp_size = fsdp_cfg.get("dp_size", None)
if self.dp_size is None:
self.dp_size = max(1, self.world_size // denom)
if ddp_cfg is not None:
# DDP uses pure data parallelism across all ranks
self.dp_size = self.world_size
else:
# FSDP may have TP/CP/PP dimensions
_fsdp_cfg = fsdp_cfg or {}
tp_size = _fsdp_cfg.get("tp_size", 1)
cp_size = _fsdp_cfg.get("cp_size", 1)
pp_size = _fsdp_cfg.get("pp_size", 1)
denom = max(1, tp_size * cp_size * pp_size)
self.dp_size = _fsdp_cfg.get("dp_size", None)
if self.dp_size is None:
self.dp_size = max(1, self.world_size // denom)

# Infer local micro-batch size from dataloader if available
self.local_batch_size = self.cfg.step_scheduler.local_batch_size
@@ -449,3 +509,21 @@ def run_train_validation_loop(self):
wandb.finish()

logging.info("[INFO] Training complete!")

def _get_dp_rank(self, include_cp: bool = False) -> int:
"""Get data parallel rank, handling DDP mode where device_mesh is None."""
# In DDP mode, device_mesh is None, so use torch.distributed directly
device_mesh = getattr(self, "device_mesh", None)
if device_mesh is None:
return dist.get_rank() if dist.is_initialized() else 0
# Otherwise, use the parent implementation
return super()._get_dp_rank(include_cp=include_cp)

def _get_dp_group_size(self, include_cp: bool = False) -> int:
"""Get data parallel world size, handling DDP mode where device_mesh is None."""
# In DDP mode, device_mesh is None, so use torch.distributed directly
device_mesh = getattr(self, "device_mesh", None)
if device_mesh is None:
return dist.get_world_size() if dist.is_initialized() else 1
# Otherwise, use the parent implementation
return super()._get_dp_group_size(include_cp=include_cp)
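As a hedged illustration of why these overrides matter in DDP mode (where no device mesh exists), the rank and world size they return are exactly what a DistributedSampler needs; 'recipe' and 'dataset' below are hypothetical stand-ins, not names from this PR.

# Illustrative sketch only: 'recipe' and 'dataset' are hypothetical stand-ins.
from torch.utils.data import DataLoader, DistributedSampler

sampler = DistributedSampler(
    dataset,
    num_replicas=recipe._get_dp_group_size(),  # falls back to dist.get_world_size() when device_mesh is None
    rank=recipe._get_dp_rank(),                # falls back to dist.get_rank() when device_mesh is None
    shuffle=True,
)
loader = DataLoader(dataset, batch_size=recipe.local_batch_size, sampler=sampler)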