23 changes: 23 additions & 0 deletions examples/dreambooth/README_flux2.md
@@ -98,6 +98,9 @@ Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take
This way, the text encoder model is not loaded into memory during training.
> [!NOTE]
> To enable remote text encoding, you must either be logged in to your Hugging Face account (`hf auth login`) or pass a token with `--hub_token`.
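For example, a launch with remote text encoding could look like the following sketch. The model id, data directory, and output directory are placeholders, and the remaining flags follow the other DreamBooth LoRA scripts in this folder:

```shell
# log in once so the remote text-encoding endpoint can authenticate your requests
hf auth login

accelerate launch train_dreambooth_lora_flux2.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.2-dev" \
  --instance_data_dir="dog" \
  --instance_prompt="a photo of sks dog" \
  --output_dir="flux2-dreambooth-lora" \
  --remote_text_encoder
```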
### FSDP Text Encoder
Flux.2 uses Mistral Small 3.1 as its text encoder, which is quite large and can take up a lot of memory. To mitigate this, you can pass the `--fsdp_text_encoder` flag to shard the text encoder with FSDP so that the prompt embeddings are computed in a distributed fashion.
This way, the memory cost of the text encoder is spread across the participating GPUs.
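A sketch of a multi-GPU launch that shards the text encoder is shown below. The number of processes, model id, and data paths are placeholders, and the remaining flags follow the other DreamBooth LoRA scripts in this folder:

```shell
accelerate launch --num_processes=2 train_dreambooth_lora_flux2.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.2-dev" \
  --instance_data_dir="dog" \
  --instance_prompt="a photo of sks dog" \
  --output_dir="flux2-dreambooth-lora" \
  --mixed_precision="bf16" \
  --fsdp_text_encoder
```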
### CPU Offloading
To offload parts of the model to CPU memory, you can use the `--offload` flag. This offloads the VAE and text encoder to CPU memory and only moves them to the GPU when they are needed.
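As a minimal sketch (model id and paths are placeholders), the flag is simply appended to the usual training command:

```shell
accelerate launch train_dreambooth_lora_flux2.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.2-dev" \
  --instance_data_dir="dog" \
  --instance_prompt="a photo of sks dog" \
  --output_dir="flux2-dreambooth-lora" \
  --offload
```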
### Latent Caching
@@ -166,6 +169,26 @@ To better track our training experiments, we're using the following flags in the
> [!NOTE]
> If you want to train using long prompts, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down training in some cases.

### FSDP on the transformer
When the Accelerate configuration enables FSDP, the transformer blocks are wrapped automatically. For example, set the configuration to:

```yaml
distributed_type: FSDP
fsdp_config:
  fsdp_version: 2
  fsdp_offload_params: false
  fsdp_sharding_strategy: HYBRID_SHARD
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Flux2TransformerBlock, Flux2SingleTransformerBlock
  fsdp_forward_prefetch: true
  fsdp_sync_module_states: false
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_use_orig_params: false
  fsdp_activation_checkpointing: true
  fsdp_reshard_after_forward: true
  fsdp_cpu_ram_efficient_loading: false
```
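Assuming the snippet above has been saved via `accelerate config` (or written to a hypothetical file such as `fsdp_config.yaml`), training can then be launched as in the sketch below; the model id and data paths are placeholders:

```shell
accelerate launch --config_file fsdp_config.yaml train_dreambooth_lora_flux2.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.2-dev" \
  --instance_data_dir="dog" \
  --instance_prompt="a photo of sks dog" \
  --output_dir="flux2-dreambooth-lora-fsdp" \
  --mixed_precision="bf16"
```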

## LoRA + DreamBooth

[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve performance comparable to full fine-tuning while training only a fraction of the parameters.
135 changes: 111 additions & 24 deletions examples/dreambooth/train_dreambooth_lora_flux2.py
@@ -44,6 +44,7 @@
import warnings
from contextlib import nullcontext
from pathlib import Path
from typing import Any

import numpy as np
import torch
@@ -75,13 +76,16 @@
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_collate_lora_metadata,
_to_cpu_contiguous,
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
get_fsdp_kwargs_from_accelerator,
offload_models,
parse_buckets_string,
wrap_with_fsdp,
)
from diffusers.utils import (
check_min_version,
@@ -93,6 +97,9 @@
from diffusers.utils.torch_utils import is_compiled_module


if getattr(torch, "distributed", None) is not None:
import torch.distributed as dist

if is_wandb_available():
import wandb

@@ -722,6 +729,7 @@ def parse_args(input_args=None):
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")

if input_args is not None:
args = parser.parse_args(input_args)
@@ -1219,7 +1227,11 @@ def main(args):
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)
transformer.to(**transformer_to_kwargs)

is_fsdp = accelerator.state.fsdp_plugin is not None
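    # With FSDP enabled, `accelerator.prepare()` takes care of device placement and sharding, so the transformer is not moved manually here.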
if not is_fsdp:
transformer.to(**transformer_to_kwargs)

if args.do_fp8_training:
convert_to_float8_training(
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1263,17 +1275,42 @@ def unwrap_model(model):

# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
transformer_cls = type(unwrap_model(transformer))

# 1) Validate and pick the transformer model
modules_to_save: dict[str, Any] = {}
transformer_model = None

for model in models:
if isinstance(unwrap_model(model), transformer_cls):
transformer_model = model
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")

if transformer_model is None:
raise ValueError("No transformer model found in 'models'")

# 2) Optionally gather FSDP state dict once
state_dict = accelerator.get_state_dict(model) if is_fsdp else None

# 3) Only main process materializes the LoRA state dict
transformer_lora_layers_to_save = None
if accelerator.is_main_process:
transformer_lora_layers_to_save = None
modules_to_save = {}
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
peft_kwargs = {}
if is_fsdp:
peft_kwargs["state_dict"] = state_dict

transformer_lora_layers_to_save = get_peft_model_state_dict(
unwrap_model(transformer_model) if is_fsdp else transformer_model,
**peft_kwargs,
)

            # with FSDP, move the gathered LoRA tensors to CPU and make them contiguous before saving
if is_fsdp:
transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)

# make sure to pop weight so that corresponding model is not saved again
if weights:
weights.pop()

Flux2Pipeline.save_lora_weights(
@@ -1285,13 +1322,20 @@ def save_model_hook(models, weights, output_dir):
def load_model_hook(models, input_dir):
transformer_ = None

while len(models) > 0:
model = models.pop()
if not is_fsdp:
while len(models) > 0:
model = models.pop()

if isinstance(model, type(unwrap_model(transformer))):
transformer_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
transformer_ = unwrap_model(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
else:
transformer_ = Flux2Transformer2DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
)
transformer_.add_adapter(transformer_lora_config)

lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)

@@ -1507,6 +1551,21 @@ def _encode_single(prompt: str):
args.validation_prompt, text_encoding_pipeline
)

# Init FSDP for text encoder
if args.fsdp_text_encoder:
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
text_encoder_fsdp = wrap_with_fsdp(
model=text_encoding_pipeline.text_encoder,
device=accelerator.device,
offload=args.offload,
limit_all_gathers=True,
use_orig_params=True,
fsdp_kwargs=fsdp_kwargs,
)

text_encoding_pipeline.text_encoder = text_encoder_fsdp
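        # make sure every rank has finished wrapping the text encoder before prompt encoding begins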
dist.barrier()

# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
@@ -1536,6 +1595,8 @@ def _encode_single(prompt: str):
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
elif args.fsdp_text_encoder:
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1777,7 +1838,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
progress_bar.update(1)
global_step += 1

if accelerator.is_main_process:
if accelerator.is_main_process or is_fsdp:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
@@ -1836,15 +1897,41 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

# Save the lora layers
accelerator.wait_for_everyone()
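    # Under FSDP, gathering the full state dict below is a collective call that every rank must join before the LoRA weights are extracted on the main process.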

if is_fsdp:
transformer = unwrap_model(transformer)
state_dict = accelerator.get_state_dict(transformer)
if accelerator.is_main_process:
modules_to_save = {}
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
if is_fsdp:
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
state_dict = {
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
else:
state_dict = {
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}

transformer_lora_layers = get_peft_model_state_dict(
transformer,
state_dict=state_dict,
)
transformer_lora_layers = {
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
for k, v in transformer_lora_layers.items()
}

else:
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)

modules_to_save["transformer"] = transformer

Flux2Pipeline.save_lora_weights(