155 changes: 0 additions & 155 deletions examples/diffusion/wan2.2/wan_generate.py

This file was deleted.

31 changes: 29 additions & 2 deletions nemo_automodel/_diffusers/auto_diffusion_pipeline.py
@@ -68,6 +68,23 @@ def _move_module_to_device(module: nn.Module, device: torch.device, torch_dtype:
module.to(device=device)


def _ensure_params_trainable(module: nn.Module, module_name: Optional[str] = None) -> int:
"""
Ensure that all parameters in the given module are trainable.

Returns the number of parameters marked trainable. If a module name is
provided, it will be used in the log message for clarity.
"""
num_trainable_parameters = 0
for parameter in module.parameters():
parameter.requires_grad = True
num_trainable_parameters += parameter.numel()
if module_name is None:
module_name = module.__class__.__name__
logger.info("[Trainable] %s: %s parameters set requires_grad=True", module_name, f"{num_trainable_parameters:,}")
return num_trainable_parameters


class NeMoAutoDiffusionPipeline(DiffusionPipeline):
"""
Drop-in Diffusers pipeline that adds optional FSDP2/TP parallelization during from_pretrained.
@@ -90,6 +107,8 @@ def from_pretrained(
device: Optional[torch.device] = None,
torch_dtype: Any = "auto",
move_to_device: bool = True,
load_for_training: bool = False,
components_to_load: Optional[Iterable[str]] = None,
**kwargs,
) -> DiffusionPipeline:
pipe: DiffusionPipeline = DiffusionPipeline.from_pretrained(
@@ -98,14 +117,22 @@
torch_dtype=torch_dtype,
**kwargs,
)

# Decide device
dev = _choose_device(device)

# Move modules to device/dtype first (helps avoid initial OOM during sharding)
if move_to_device:
for name, module in _iter_pipeline_modules(pipe):
_move_module_to_device(module, dev, torch_dtype)
if not components_to_load or name in components_to_load:
logger.info("[INFO] Moving module: %s to device/dtype", name)
_move_module_to_device(module, dev, torch_dtype)

# If loading for training, ensure the target module parameters are trainable
if load_for_training:
for name, module in _iter_pipeline_modules(pipe):
if not components_to_load or name in components_to_load:
logger.info("[INFO] Ensuring params trainable: %s", name)
_ensure_params_trainable(module, module_name=name)

# Use per-component FSDP2Manager mappings to parallelize components
if parallel_scheme is not None:
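For context, a minimal usage sketch of the two new from_pretrained arguments (not part of this diff; the checkpoint id and the "transformer" component name are illustrative placeholders, assuming a Wan-style pipeline):

    # Minimal sketch, assuming a Wan-style pipeline; checkpoint id and
    # component name below are placeholders, not taken from this PR.
    import torch
    from nemo_automodel._diffusers.auto_diffusion_pipeline import NeMoAutoDiffusionPipeline

    pipe = NeMoAutoDiffusionPipeline.from_pretrained(
        "Wan-AI/Wan2.2-T2V-A14B-Diffusers",  # placeholder checkpoint id
        torch_dtype=torch.bfloat16,
        load_for_training=True,              # mark parameters of selected components requires_grad=True
        components_to_load=["transformer"],  # only this component is moved/marked; others are left untouched
    )

With components_to_load unset (None or empty), all pipeline modules are moved and, if load_for_training is set, marked trainable, matching the pre-existing behavior.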
13 changes: 12 additions & 1 deletion nemo_automodel/components/distributed/parallelizer.py
@@ -185,7 +185,6 @@ def parallelize(
reduce_dtype=torch.float32,
output_dtype=torch.float32,
)

# Find transformer layers and apply parallelisms
apply_fsdp2_sharding_recursively(model, dp_mesh, mp_policy, offload_policy)

@@ -349,6 +348,18 @@ def parallelize(
reduce_dtype=torch.float32,
output_dtype=torch.float32,
)
# Apply activation checkpointing to transformer blocks if requested
if activation_checkpointing:
try:
if hasattr(model, "blocks") and isinstance(model.blocks, nn.ModuleList):
for idx, blk in enumerate(model.blocks):
model.blocks[idx] = checkpoint_wrapper(blk)
elif hasattr(model, "blocks"):
# Fallback if blocks is an iterable but not ModuleList
for idx, _ in enumerate(list(model.blocks)):
model.blocks[idx] = checkpoint_wrapper(model.blocks[idx])
except Exception as e:
logger.warning(f"Wan strategy: failed to apply activation checkpointing: {e}")

# Apply FSDP sharding recursively and to root
apply_fsdp2_sharding_recursively(model, dp_mesh, mp_policy, offload_policy)
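Below is a minimal, self-contained sketch (not part of this diff) of the activation-checkpointing pattern applied above. ToyBlock and ToyModel are hypothetical stand-ins for a Wan-style transformer that exposes a blocks ModuleList; the real model and surrounding FSDP2 setup are omitted.

    # Minimal sketch, assuming a model with a `blocks` nn.ModuleList;
    # ToyBlock/ToyModel are hypothetical and only illustrate the wrapping pattern.
    import torch
    import torch.nn as nn
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper


    class ToyBlock(nn.Module):
        def __init__(self, dim: int = 64):
            super().__init__()
            self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

        def forward(self, x):
            return x + self.ff(x)


    class ToyModel(nn.Module):
        def __init__(self, num_blocks: int = 4, dim: int = 64):
            super().__init__()
            self.blocks = nn.ModuleList(ToyBlock(dim) for _ in range(num_blocks))

        def forward(self, x):
            for blk in self.blocks:
                x = blk(x)
            return x


    model = ToyModel()
    # Replace each block in place so its activations are recomputed during the
    # backward pass instead of being stored, trading compute for memory.
    for idx, blk in enumerate(model.blocks):
        model.blocks[idx] = checkpoint_wrapper(blk)

    out = model(torch.randn(2, 8, 64))
    out.sum().backward()

In the PR itself this wrapping happens before apply_fsdp2_sharding_recursively, so the FSDP2 units are built around the already-checkpointed blocks.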